Tensor的索引、切片、变换、拼接和拆分操作,举例
时间: 2023-05-30 07:02:59 浏览: 101
Tensorflow进行多维矩阵的拆分与拼接实例
索引操作:
假设我们有一个3维Tensor,形状为(2,3,4),可以使用下面的代码进行索引:
```
import torch
# 创建一个3维Tensor
x = torch.randn(2,3,4)
# 获取第一个元素
print(x[0][0][0])
# 获取第一维的所有元素,即x[0]和x[1]
print(x[:])
# 获取第一维的第二个元素,即x[1]
print(x[1])
# 获取第一维的前两个元素,即x[0]和x[1]
print(x[:2])
# 获取第二维和第三维的前两个元素,即x[:,:,0:2]
print(x[:,:,0:2])
```
切片操作:
可以使用切片操作获取Tensor的子集,下面是一些例子:
```
import torch
# 创建一个2维Tensor
x = torch.randn(3,3)
# 获取第一行和第二行,即x[0:2,:]
print(x[0:2,:])
# 获取第二列和第三列,即x[:,1:3]
print(x[:,1:3])
# 获取所有行和第一列,即x[:,0]
print(x[:,0])
# 获取所有行和第一列和第三列,即x[:,[0,2]]
print(x[:,[0,2]])
```
变换操作:
可以使用变换操作改变Tensor的形状,下面是一些例子:
```
import torch
# 创建一个2维Tensor
x = torch.randn(3,4)
# 将Tensor变换为4行3列的Tensor
print(x.view(4,3))
# 将Tensor变换为1行12列的Tensor
print(x.view(1,12))
# 将Tensor变换为12行1列的Tensor
print(x.view(12,1))
# 将Tensor变换为3维Tensor,其中第一维为2,第二维为2,第三维为3
print(x.view(2,2,3))
```
拼接操作:
可以使用拼接操作将多个Tensor合并成一个Tensor,下面是一些例子:
```
import torch
# 创建两个2维Tensor
x = torch.randn(2,3)
y = torch.randn(2,4)
# 将两个Tensor沿着第二维进行拼接,即变为2行7列的Tensor
print(torch.cat([x,y], dim=1))
# 创建两个3维Tensor
x = torch.randn(2,3,4)
y = torch.randn(2,3,5)
# 将两个Tensor沿着第三维进行拼接,即变为2行3列9层的Tensor
print(torch.cat([x,y], dim=2))
```
拆分操作:
可以使用拆分操作将一个Tensor拆分成多个Tensor,下面是一些例子:
```
import torch
# 创建一个2维Tensor
x = torch.randn(2,6)
# 将Tensor沿着第二维拆分成3个Tensor,每个Tensor有2行2列
print(torch.split(x, [2,2,2], dim=1))
# 创建一个3维Tensor
x = torch.randn(2,3,6)
# 将Tensor沿着第三维拆分成3个Tensor,每个Tensor有2行3列2层
print(torch.split(x, [2,2,2], dim=2))
```
阅读全文