torch.unsqu
时间: 2024-01-11 09:22:29 浏览: 95
torch.unsqueeze()是PyTorch中的一个函数,用于在指定维度上对张量进行扩展。它的作用是在指定的维度上增加一个维度,扩展张量的形状。
函数的语法如下:
torch.unsqueeze(input, dim)
其中,input是要进行扩展的张量,dim是要在哪个维度上进行扩展。
下面是一个示例代码,演示了如何使用torch.unsqueeze()函数:
```python
import torch
# 创建一个2维张量
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
# 对第0维进行扩展
y = torch.unsqueeze(x, 0)
print(y)
# 输出结果为:
# tensor([[[1, 2, 3],
# [4, 5, 6]]])
# 对第1维进行扩展
z = torch.unsqueeze(x, 1)
print(z)
# 输出结果为:
# tensor([[[1, 2, 3]],
# [[4, 5, 6]]])
```
在上面的示例中,我们首先创建了一个2维张量x。然后使用torch.unsqueeze()函数对x进行扩展。在第一个例子中,我们对第0维进行扩展,结果是一个3维张量。在第二个例子中,我们对第1维进行扩展,结果也是一个3维张量。
这样,我们就可以通过torch.unsqueeze()函数在指定维度上对张量进行扩展了。
阅读全文