torch.unique怎么使用
时间: 2024-02-17 15:04:32 浏览: 143
python torch.utils.data.DataLoader使用方法
5星 · 资源好评率100%
`torch.unique(input, sorted=False, return_inverse=False, dim=None)` 是 PyTorch 中的一个函数,它可以返回一个张量中的不同元素,并按照它们在原始张量中出现的顺序进行排序。
参数说明:
- `input`:输入张量。
- `sorted`:是否对结果进行排序,默认为 `False`。
- `return_inverse`:是否同时返回一个索引张量,使得 `unique(input)[0][return_inverse]` 等价于 `input`,默认为 `False`。
- `dim`:沿着哪个维度进行去重,默认为 `None`,表示将整个张量视为一个一维向量。
使用示例:
```python
import torch
# 一维张量去重
x = torch.tensor([1, 3, 2, 1, 2, 4])
y = torch.unique(x)
print(y) # 输出 tensor([1, 2, 3, 4])
# 二维张量按行去重
x = torch.tensor([[1, 3], [2, 1], [2, 4], [1, 3]])
y = torch.unique(x, dim=0)
print(y) # 输出 tensor([[1, 3], [2, 1], [2, 4]])
# 二维张量按列去重
x = torch.tensor([[1, 3], [2, 1], [2, 4], [1, 3]])
y = torch.unique(x, dim=1)
print(y) # 输出 tensor([[1, 3], [2, 1], [4, 2], [3, 1]])
```
如果设置了 `sorted=True`,则返回的结果张量会按照升序排列。如果设置了 `return_inverse=True`,则还会返回一个索引张量,可以通过 `torch.index_select(input, 0, inverse_indices)` 来恢复原始张量 `input`。
阅读全文