pytorch去掉tensor维度为1
时间: 2024-04-14 13:31:10 浏览: 270
PyTorch中Tensor的维度变换实现
要去掉PyTorch张量中维度为1的维度,可以使用`squeeze()`函数。`squeeze()`函数会删除所有维度为1的维度。如果你只想删除特定的维度,可以使用`squeeze(dim)`,其中`dim`是要删除的维度的索引。下面是一个例子:
```python
import torch
# 创建一个维度为1的张量
x = torch.randn(3, 1, 2)
print("原始张量:")
print(x)
print("原始张量形状:", x.shape)
# 去掉维度为1的维度
y = x.squeeze()
print("去掉维度为1的维度后的张量:")
print(y)
print("去掉维度为1的维度后的张量形状:", y.shape)
```
输出结果如下:
```
原始张量:
tensor([[[-0.1053, 0.5799]],
[[ 1.5041, -0.5413]],
[[ 0.3897, -0.4527]]])
原始张量形状: torch.Size([3, 1, 2])
去掉维度为1的维度后的张量:
tensor([[-0.1053, 0.5799],
[ 1.5041, -0.5413],
[ 0.3897, -0.4527]])
去掉维度为1的维度后的张量形状: torch.Size([3, 2])
```
如上所示,使用`squeeze()`函数可以去掉维度为1的维度。
阅读全文