鸢尾花分类预测 pytorch
时间: 2024-10-08 17:18:03 浏览: 42
在PyTorch中,鸢尾花分类预测通常是一个用于入门深度学习的经典案例,特别是对于初学者来说,因为它数据集较小,模型结构简单。鸢尾花数据集(Iris dataset)包含三种不同种类的鸢尾花特征,如花瓣长度、宽度等,目标是根据这些特征预测其所属的鸢尾花类别。
首先,你需要加载这个数据集,可以使用scikit-learn库提供。然后,将数据预处理成PyTorch所需的张量格式。接着,你可以构建一个基础的神经网络模型,比如全连接层或多层感知机(MLP),作为分类器。训练过程包括定义损失函数(如交叉熵)、优化器(如SGD或Adam)以及迭代训练和验证。
以下是一个简单的示例:
```python
import torch
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader
# 加载数据并划分训练集和测试集
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2)
# 将数据转化为PyTorch张量
X_train_tensor = torch.tensor(X_train, dtype=torch.float)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
X_test_tensor = torch.tensor(X_test, dtype=torch.float)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)
# 创建数据加载器
train_data = TensorDataset(X_train_tensor, y_train_tensor)
test_data = TensorDataset(X_test_tensor, y_test_tensor)
train_loader = DataLoader(train_data, batch_size=8, shuffle=True)
test_loader = DataLoader(test_data, batch_size=len(X_test))
# 构建模型,这里假设是两个隐藏层的MLP
model = torch.nn.Sequential(
torch.nn.Linear(4, 6), # 输入维度4到隐层维度6
torch.nn.ReLU(),
torch.nn.Linear(6, 3), # 隐层维度到类别数3
torch.nn.Softmax(dim=1) # 输出层softmax
)
# 损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 训练模型
num_epochs = 50
for epoch in range(num_epochs):
for inputs, labels in train_loader:
optimizer.zero_grad() # 清零梯度
outputs = model(inputs) # 前向传播
loss = criterion(outputs, labels) # 计算损失
loss.backward() # 反向传播
optimizer.step() # 更新权重
# 测试模型性能
model.eval()
with torch.no_grad():
correct = 0
total = 0
for inputs, labels in test_loader:
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f"Test accuracy: {accuracy}%")
```
阅读全文