pytorch中维度
时间: 2023-11-21 18:56:20 浏览: 72
在PyTorch中,维度用`dim`表示,可以用来指定在哪个维度上进行操作。下面是一些常见的PyTorch维度操作:
1. `torch.unsqueeze(input, dim)`:在指定维度上增加一个维度,返回一个新的张量。例如:
```python
import torch
x = torch.tensor([1, 2, 3]) # 一维张量
x = torch.unsqueeze(x, 0) # 在第0维增加一个维度
print(x) # 输出:tensor([[1, 2, 3]])
```
2. `torch.squeeze(input, dim)`:在指定维度上去掉一个维度,返回一个新的张量。例如:
```python
import torch
x = torch.tensor([[1, 2, 3]]) # 二维张量
x = torch.squeeze(x, 0) # 去掉第0维
print(x) # 输出:tensor([1, 2, 3])
```
3. `torch.transpose(input, dim0, dim1)`:交换两个维度的位置,返回一个新的张量。例如:
```python
import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6]]) # 二维张量
x = torch.transpose(x, 0, 1) # 交换第0维和第1维
print(x) # 输出:tensor([[1, 4],
# [2, 5],
# [3, 6]])
```
4. `torch.cat(inputs, dim)`:在指定维度上将多个张量拼接起来,返回一个新的张量。例如:
```python
import torch
x1 = torch.tensor([[1, 2, 3]])
x2 = torch.tensor([[4, 5, 6]])
x = torch.cat((x1, x2), dim=0) # 在第0维上拼接
print(x) # 输出:tensor([[1, 2, 3],
# [4, 5, 6]])
```
5. `torch.stack(inputs, dim)`:在指定维度上将多个张量堆叠起来,返回一个新的张量。例如:
```python
import torch
x1 = torch.tensor([1, 2, 3])
x2 = torch.tensor([4, 5, 6])
x = torch.stack((x1, x2), dim=0) # 在第0维上堆叠
print(x) # 输出:tensor([[1, 2, 3],
# [4, 5, 6]])
```
阅读全文