torch.reshape()函数的用法
时间: 2023-09-13 11:10:24 浏览: 675
`torch.reshape()`函数是PyTorch中的一个张量操作函数,用于改变张量的形状。
它的语法如下:
```python
torch.reshape(input, shape)
```
参数说明:
- `input`:要改变形状的输入张量。
- `shape`:一个表示目标形状的元组或列表。
函数返回一个具有指定形状的新张量,该新张量与原始张量具有相同的数据。
以下是一些示例:
```python
import torch
# 改变张量的形状为(2, 3)
x = torch.tensor([1, 2, 3, 4, 5, 6])
reshaped_tensor = torch.reshape(x, (2, 3))
print(reshaped_tensor)
# 输出: tensor([[1, 2, 3],
# [4, 5, 6]])
# 改变张量的形状为(3, 2)
x = torch.tensor([1, 2, 3, 4, 5, 6])
reshaped_tensor = torch.reshape(x, (3, 2))
print(reshaped_tensor)
# 输出: tensor([[1, 2],
# [3, 4],
# [5, 6]])
# 改变张量的形状为(2, 2, 2)
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
reshaped_tensor = torch.reshape(x, (2, 2, 2))
print(reshaped_tensor)
# 输出: tensor([[[1, 2],
# [3, 4]],
#
# [[5, 6],
# [7, 8]]])
```
这些示例展示了`torch.reshape()`函数的基本用法,你可以根据自己的需求使用它来改变张量的形状。注意,新形状的元素数量必须与原始张量的元素数量相同,否则会引发错误。
阅读全文