A.unsqueeze(2)
时间: 2024-08-15 14:09:41 浏览: 38
pytorch masked_fill报错的解决
`unsqueeze` 是 PyTorch 中的一个功能,它用于在张量的一个维度上添加一个尺寸1,常用于扩充形状以适应神经网络的输入要求。当你调用 `A.unsqueeze(2)` 时,假设 `A` 是一个一维或二维张量(比如一个向量或矩阵),这会创建一个新的张量,其中 `A` 的原始维度保持不变,但在第二维度(索引为2)增加了新的长度1。
举个例子[^4]:
```python
import torch
# 假设 A 是一个一维张量
A = torch.tensor([1, 2, 3])
# 使用 unsqueeze 在第二个维度增加1
B = A.unsqueeze(2)
print(B.shape) # 输出: torch.Size([3, 1])
# 或者 A 是一个二维张量 (行数 x 列数)
A = torch.tensor([[1, 2], [3, 4]])
B = A.unsqueeze(2)
print(B.shape) # 输出: torch.Size([2, 2, 1])
```
这个操作通常在卷积神经网络(CNN)中很有用,当你的输入数据不满足特定卷积层期望的输入维度时,可以使用 `unsqueeze` 来扩展数据。
阅读全文