torch 选择二维张量第二维存储在另一个张量中的序号对应值
时间: 2024-05-15 18:13:54 浏览: 92
pytorch模型存储的2种实现方法
可以使用 `torch.gather()` 函数来实现这个功能。具体来说,假设有一个二维张量 `A`,它的第二维存储了另一个张量 `B` 中的序号,我们想要得到一个新的张量 `C`,它的值是 `B` 中对应序号的值。
以下是示例代码:
```python
import torch
# 创建示例张量 A 和 B
A = torch.randn(3, 4)
B = torch.tensor([[1, 3, 0, 2], [2, 1, 3, 0], [0, 2, 1, 3]])
# 使用 gather 函数获取对应值
C = torch.gather(A, 1, B)
print("A:")
print(A)
print("B:")
print(B)
print("C:")
print(C)
```
输出结果如下:
```
A:
tensor([[ 0.5794, 1.0303, -1.7010, -0.0530],
[ 0.5791, -1.5785, 0.1576, 1.1485],
[-0.0741, -0.4228, -0.8522, 1.6019]])
B:
tensor([[1, 3, 0, 2],
[2, 1, 3, 0],
[0, 2, 1, 3]])
C:
tensor([[ 1.0303, -0.0530, 0.5794, -1.7010],
[ 0.1576, -1.5785, 1.1485, 0.5791],
[-0.0741, -0.8522, -0.4228, 1.6019]])
```
可以看到,张量 `C` 的每一行对应于 `A` 中的一行,而每一列对应于 `B` 中的一个序号对应的值。
阅读全文