pytorch鸢尾花
时间: 2023-10-29 22:00:40 浏览: 97
PyTorch 是一个基于 Python 的开源机器学习库,它提供了丰富的工具和接口,用于构建深度学习模型。鸢尾花数据集是一个经典的机器学习数据集,其中包含了三个不同种类的鸢尾花(山鸢尾、变色鸢尾和维吉尼亚鸢尾)的测量数据。使用 PyTorch 可以很方便地加载这个数据集并构建分类模型进行训练和预测。
以下是一个简单的使用 PyTorch 构建鸢尾花分类模型的示例代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target
# 数据预处理
scaler = StandardScaler()
X = scaler.fit_transform(X)
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 转换为 Tensor
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.long)
# 定义模型
class IrisClassifier(nn.Module):
def __init__(self):
super(IrisClassifier, self).__init__()
self.fc1 = nn.Linear(4, 16)
self.fc2 = nn.Linear(16, 3)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
model = IrisClassifier()
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
# 训练模型
num_epochs = 100
for epoch in range(num_epochs):
optimizer.zero_grad()
outputs = model(X_train)
loss = criterion(outputs, y_train)
loss.backward()
optimizer.step()
# 在测试集上评估模型
with torch.no_grad():
outputs = model(X_test)
_, predicted = torch.max(outputs.data, 1)
accuracy = (predicted == y_test).sum().item() / len(y_test)
print('Accuracy: {:.2f}%'.format(accuracy * 100))
```
这段代码使用了一个简单的全连接神经网络模型来进行鸢尾花分类任务,通过训练和测试数据集来评估模型的准确性。
阅读全文