pytorch one-hot向量
时间: 2023-09-05 10:03:53 浏览: 176
PyTorch中的one-hot向量是使用独热编码(one-hot encoding)方法来表示离散变量的一种方式。独热编码是将离散变量的每个取值表示为一个长度为变量总数的二进制向量,只有对应的取值位置为1,其余位置都为0。
在PyTorch中,可以使用torch.nn.functional.one_hot函数来实现对离散变量进行独热编码。该函数接受一个整数张量和类别的总数作为输入,返回一个与输入相同形状的张量,其中每个元素都是对应类别的独热编码。
例如,假设有一个大小为(3,)的整数张量x,表示三个样本的类别,类别总数为5。可以使用torch.nn.functional.one_hot函数对x进行独热编码如下:
```python
import torch
import torch.nn.functional as F
x = torch.tensor([2, 4, 1])
one_hot_x = F.one_hot(x, num_classes=5)
print(one_hot_x)
```
运行上述代码会得到一个大小为(3, 5)的张量one_hot_x,表示三个样本在五个类别中的独热编码。对应的独热编码如下:
```
tensor([[0, 0, 1, 0, 0],
[0, 0, 0, 0, 1],
[0, 1, 0, 0, 0]])
```
可以看到,第一个样本对应的类别是2,独热编码的结果是[0, 0, 1, 0, 0]。第二个样本对应的类别是4,独热编码的结果是[0, 0, 0, 0, 1]。第三个样本对应的类别是1,独热编码的结果是[0, 1, 0, 0, 0]。
阅读全文