A 为N*K 的Tensor B为N*1的索引 将A按照B索引做变换
时间: 2024-11-03 13:23:30 浏览: 21
Pytorch 多维数组运算过程的索引处理方式
如果你有一个N*K的二维张量(Tensor)A和一个N*1的一维索引向量B,你想按照B的索引值改变A的部分元素,你可以使用Python库Pandas或者NumPy的indexing机制。这里是一个简单的例子:
```python
import torch # 假如使用的是PyTorch,如果是其他库例如Pandas,则用相应的API
# 假设 A 是一个 N * K 的 PyTorch tensor
# A = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float)
# B 是一个 N * 1 的一维 index vector
# B = torch.tensor([0, 1, 2])
# 使用 torch.index_select() 或者 .index_getitem()
selected_values = A[torch.arange(A.shape[0]), B]
# 现在 selected_values 是一个新的 Tensor,形状为 (N, 1),包含了原tensor A 中对应索引B的位置的元素
```
这个操作会返回一个新的张量,其每一行对应于A的某一行,其列则是由索引B指定的A的那一列。
阅读全文