如何用nn.CrossEntropyLoss做3分类
时间: 2024-03-25 15:39:57 浏览: 65
你可以使用 PyTorch 中的 nn.CrossEntropyLoss 来进行三分类问题的训练。具体来说,你可以使用以下代码:
```python
import torch.nn as nn
# 定义你的模型
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(in_features=10, out_features=20)
self.fc2 = nn.Linear(in_features=20, out_features=3)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
# 定义损失函数
loss_fn = nn.CrossEntropyLoss()
# 定义优化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# 训练模型
for epoch in range(num_epochs):
for inputs, labels in train_loader:
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()
```
在这个例子中,我们定义了一个包含两个全连接层的模型,并使用 nn.CrossEntropyLoss 作为损失函数。模型的输出是一个大小为 3 的张量,分别表示每个类别的分数。在训练过程中,我们将模型的输出结果和真实的标签传递给损失函数,计算损失并反向传播更新参数。
需要注意的是,nn.CrossEntropyLoss 要求模型的最后一层输出的大小必须等于类别的数量。在本例中,我们有三个类别,因此最后一层的输出大小为 3。
阅读全文