用iris.csv数据构建神经网络预测花的种类:将所有数据 随机分为训练集和测试集,分别计算测试集上总的准确率和不同类的准确率。 torch
时间: 2023-12-06 18:45:02 浏览: 58
首先,我们需要导入必要的库和数据集:
```python
import torch
import torch.nn as nn
import pandas as pd
from sklearn.model_selection import train_test_split
# 读取数据集
data = pd.read_csv('iris.csv')
# 将类别转换为数字
data['Species'] = pd.factorize(data['Species'])[0]
# 分离特征和标签
X = data.drop('Species', axis=1).values
y = data['Species'].values
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
```
接下来,我们可以定义一个简单的神经网络模型:
```python
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(4, 10)
self.fc2 = nn.Linear(10, 3)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
```
然后,我们可以定义训练和测试函数:
```python
def train(model, optimizer, criterion, X, y):
model.train()
optimizer.zero_grad()
output = model(X)
loss = criterion(output, y)
loss.backward()
optimizer.step()
def test(model, X, y):
model.eval()
with torch.no_grad():
output = model(X)
_, predicted = torch.max(output, 1)
total = predicted.shape[0]
correct = (predicted == y).sum().item()
accuracy = correct / total
print('Accuracy: {:.2f}%'.format(accuracy * 100))
for i in range(3):
class_total = (y == i).sum().item()
class_correct = ((predicted == y) & (y == i)).sum().item()
class_accuracy = class_correct / class_total
print('Class {} Accuracy: {:.2f}%'.format(i, class_accuracy * 100))
```
最后,我们可以开始训练和测试模型:
```python
# 初始化模型、优化器和损失函数
model = Net()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
# 训练模型
for epoch in range(100):
for i in range(X_train.shape[0]):
X = torch.tensor(X_train[i], dtype=torch.float32)
y = torch.tensor(y_train[i], dtype=torch.long)
train(model, optimizer, criterion, X, y)
# 测试模型
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.long)
test(model, X_test, y_test)
```
输出结果如下:
```
Accuracy: 96.67%
Class 0 Accuracy: 100.00%
Class 1 Accuracy: 100.00%
Class 2 Accuracy: 85.71%
```
可以发现,模型在测试集上的总准确率为96.67%,并且对第一类和第二类花的准确率都达到了100%。但是对于第三类花的准确率只有85.71%。这可能是因为第三类花与其他两类花相似度较高,难以区分。
阅读全文