用pytorch写一个非图片分类的的多分类,直接可以运行的完整代码
时间: 2024-05-02 19:18:28 浏览: 14
以下是一个使用PyTorch实现的非图片分类的多分类任务的完整代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
# 定义数据集类
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __getitem__(self, index):
x, y = self.data[index]
return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.long)
def __len__(self):
return len(self.data)
# 定义模型类
class MyModel(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, num_classes)
def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
return out
# 定义超参数
input_size = 4
hidden_size = 8
num_classes = 3
learning_rate = 0.01
num_epochs = 100
# 定义数据
data = [
([5.1, 3.5, 1.4, 0.2], 0),
([4.9, 3.0, 1.4, 0.2], 0),
([4.7, 3.2, 1.3, 0.2], 0),
([7.0, 3.2, 4.7, 1.4], 1),
([6.4, 3.2, 4.5, 1.5], 1),
([6.9, 3.1, 4.9, 1.5], 1),
([6.1, 2.8, 4.0, 1.3], 1),
([6.3, 3.3, 6.0, 2.5], 2),
([5.8, 2.7, 5.1, 1.9], 2),
([7.1, 3.0, 5.9, 2.1], 2)
]
# 划分训练集和测试集
train_data = data[:7]
test_data = data[7:]
# 定义数据加载器
train_dataset = MyDataset(train_data)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
# 定义模型和优化器
model = MyModel(input_size, hidden_size, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
# 训练模型
for epoch in range(num_epochs):
for i, (inputs, labels) in enumerate(train_loader):
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
if (i+1) % 2 == 0:
print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(train_loader), loss.item()))
# 在测试集上验证模型的准确率
test_dataset = MyDataset(test_data)
test_loader = DataLoader(test_dataset, batch_size=1)
with torch.no_grad():
correct = 0
total = 0
for inputs, labels in test_loader:
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += 1
correct += (predicted == labels).sum().item()
print('Test Accuracy: {:.2f}%'.format(correct / total * 100))
```
该代码实现了一个使用PyTorch实现的非图片分类的多分类任务,其中使用了一个简单的神经网络模型和随机梯度下降优化器。在训练过程中,将训练集划分为批次进行训练,并输出每个批次的损失值。在测试集上验证模型的准确率。
相关推荐
![rar](https://img-home.csdnimg.cn/images/20210720083606.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)