pytorch从一个有重复值的tensor中随机取5个下标,且取的下标在原数据上对应的数值不重复
时间: 2024-04-09 21:28:24 浏览: 90
pytorch实现保证每次运行使用的随机数都相同
要从一个有重复值的tensor中随机取5个下标,且取的下标在原数据上对应的数值不重复,你可以使用`torch.unique()`函数来获取tensor中的唯一值,并结合`torch.randperm()`函数来生成随机下标。然后,你可以遍历这些随机下标,检查对应的数值是否已经被选择过,如果已经选择过则进行下一个随机下标的选择。下面是一个示例代码:
```python
import torch
# 原始的tensor
tensor = torch.tensor([1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 10])
# 获取tensor中的唯一值
unique_values = torch.unique(tensor)
# 生成随机下标
indices = []
selected_values = []
while len(indices) < 5:
index = torch.randint(0, unique_values.size(0), (1,))
value = unique_values[index]
if value not in selected_values:
indices.append(index)
selected_values.append(value)
indices = torch.cat(indices)
selected_values = torch.cat(selected_values)
print(indices)
print(selected_values)
```
这段代码会输出随机生成的下标以及对应的数值。由于需要检查每次选择的数值是否已经被选择过,因此可能需要多次循环才能找到5个不重复的下标。注意,由于是随机生成的,每次运行结果可能会不同。
阅读全文