多维tensor的数组索引该如何指定维度
时间: 2023-11-30 09:04:48 浏览: 31
在使用整数数组对多维tensor进行索引时,我们可以通过在整数数组前面添加一个冒号来指定要进行索引的维度。具体来说,如果要对第1维进行索引,可以将整数数组放在一个列表中,然后在列表前面添加一个冒号,例如:
```python
import torch
# 创建一个3x3x3的tensor
x = torch.tensor([
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
[[10, 11, 12], [13, 14, 15], [16, 17, 18]],
[[19, 20, 21], [22, 23, 24], [25, 26, 27]]
])
# 对第1维进行索引
indices = torch.tensor([0, 2])
y = x[indices, :, :]
print(y)
```
运行结果如下:
```
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],
[[19, 20, 21],
[22, 23, 24],
[25, 26, 27]]])
```
在上面的例子中,我们使用整数数组`[0, 2]`对第1维进行索引,同时使用冒号对后面两维进行全局索引,因此输出结果为:
```
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],
[[19, 20, 21],
[22, 23, 24],
[25, 26, 27]]])
```
需要注意的是,在使用多维数组索引时,我们可以对多个维度进行索引,但是每个维度上的索引数组的长度必须相同,否则会抛出`IndexError`异常。