pytorch softmax多分类
时间: 2023-11-20 14:57:03 浏览: 37
以下是使用PyTorch实现Softmax多分类的示例代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
# 定义模型
class SoftmaxClassifier(nn.Module):
def __init__(self, input_dim, output_dim):
super(SoftmaxClassifier, self).__init__()
self.linear = nn.Linear(input_dim, output_dim)
def forward(self, x):
out = self.linear(x)
return out
# 定义超参数
input_dim = 4
output_dim = 3
lr = 0.01
epochs = 100
# 准备数据
X = torch.tensor([[6.4, 2.8, 5.6, 2.2], [5.0, 2.3, 3.3, 1.0], [4.9, 3.1, 1.5, 0.1], [6.0, 3.0, 4.8, 1.8]], dtype=torch.float32)
Y = torch.tensor([2, 1, 0, 1], dtype=torch.long)
# 初始化模型和损失函数
model = SoftmaxClassifier(input_dim, output_dim)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr)
# 训练模型
for epoch in range(epochs):
optimizer.zero_grad()
outputs = model(X)
loss = criterion(outputs, Y)
loss.backward()
optimizer.step()
if (epoch+1) % 10 == 0:
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, epochs, loss.item()))
# 测试模型
with torch.no_grad():
outputs = model(X)
_, predicted = torch.max(outputs.data, 1)
print('Predicted:', predicted)
```
在这个示例中,我们定义了一个Softmax分类器模型,使用交叉熵损失函数和随机梯度下降优化器进行训练。我们使用鸢尾花数据集作为示例数据集,其中有4个特征和3个类别。在训练过程中,我们输出每个epoch的损失值,并在训练结束后使用训练好的模型进行预测。