pytorch 将整型标签转为onehot
时间: 2023-09-09 12:01:30 浏览: 180
在PyTorch中,我们可以使用torch.nn.functional.one_hot函数将整型标签转换为one-hot编码。此函数的输入是一个整数张量和一个表示类别总数的标量张量。该函数会返回一个表示每个类别的one-hot编码的张量。
下面是一个使用示例:
```python
import torch
import torch.nn.functional as F
# 创建一个包含4个类别的整型标签
labels = torch.tensor([0, 2, 1, 3])
# 计算类别总数
num_classes = labels.max().item() + 1
# 将整型标签转换为one-hot编码
onehot_labels = F.one_hot(labels, num_classes)
print(onehot_labels)
```
输出结果为:
```
tensor([[1, 0, 0, 0],
[0, 0, 1, 0],
[0, 1, 0, 0],
[0, 0, 0, 1]], dtype=torch.float32)
```
在这个示例中,我们有4个整型标签[0, 2, 1, 3],类别总数为4。使用F.one_hot函数将整型标签转换为one-hot编码后,得到一个形状为(4, 4)的张量,其中每一行表示一个标签的one-hot编码。
需要注意的是,F.one_hot函数返回的张量默认为浮点型。如果需要整型类型的one-hot编码,可以使用.to(dtype=torch.int)方法将张量的数据类型转换为整型。
阅读全文