利用Iris(鸢尾花)数据集实现多分类任务:PyTorch + 多层感知机(MLP)代码及实现
时间: 2024-10-12 15:08:52 浏览: 65
在Python中,我们可以利用PyTorch库以及Iris数据集来进行多分类任务,如鸢尾花识别。这里是一个简单的例子,展示如何使用多层感知机(MLP)模型进行训练和预测:
```python
# 导入所需库
import torch
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
# Load the Iris dataset (feature matrix in X, class labels in y).
iris = load_iris()
X = iris.data
y = iris.target

# Split first, then fit the scaler on the training portion only, so that
# test-set statistics do not leak into the preprocessing step.
X_train_raw, X_test_raw, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_raw)
X_test = scaler.transform(X_test_raw)

# Scaled copy of the full feature matrix, kept under its original name
# because later code reads X_scaled.shape[1] for the input dimension.
X_scaled = scaler.transform(X)

# nn.Linear parameters are float32 by default, whereas the NumPy arrays
# produced above are float64; cast explicitly so the forward pass does not
# fail with a dtype mismatch. CrossEntropyLoss expects int64 class indices.
train_data = TensorDataset(
    torch.as_tensor(X_train, dtype=torch.float32),
    torch.as_tensor(y_train, dtype=torch.long),
)
test_data = TensorDataset(
    torch.as_tensor(X_test, dtype=torch.float32),
    torch.as_tensor(y_test, dtype=torch.long),
)
# 定义多层感知机模型
class IrisClassifier(nn.Module):
    """Multi-layer perceptron with a single hidden layer.

    Architecture: Linear -> ReLU -> Dropout -> Linear. The forward pass
    returns raw, unnormalised class scores (logits); no softmax is applied.

    Args:
        input_dim: Number of input features per sample.
        hidden_dim: Number of units in the hidden layer.
        num_classes: Number of output classes (size of the logit vector).
        dropout_p: Dropout probability applied after the hidden activation.
            Defaults to 0.5.
    """

    def __init__(self, input_dim, hidden_dim, num_classes, dropout_p=0.5):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_p)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return the logits for a batch ``x`` whose last dim is ``input_dim``."""
        out = self.fc1(x)
        out = self.relu(out)
        # Dropout is only active in training mode; it is a no-op after eval().
        out = self.dropout(out)
        out = self.fc2(out)
        return out
# Model hyperparameters.
_, input_dim = X_scaled.shape  # number of input features (second axis)
hidden_dim = 64  # number of hidden-layer units
num_classes = len(set(y))  # number of distinct labels

model = IrisClassifier(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    num_classes=num_classes,
)

# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)
# Train the model.
epochs = 100
# Build the loader once; with shuffle=True the data is still reshuffled
# every time the loader is iterated, i.e. once per epoch.
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
model.train()  # explicitly select training mode
for epoch in range(epochs):
    for inputs, labels in train_loader:
        optimizer.zero_grad()  # clear gradients left over from the previous step
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
# Evaluate the model on the test set.
model.eval()  # switch to evaluation mode
correct = 0
total = 0
with torch.no_grad():  # no gradients are needed for inference
    # Evaluate in batches rather than with the default batch size of 1.
    for inputs, labels in DataLoader(test_data, batch_size=32):
        outputs = model(inputs)
        # Index of the largest logit along dim 1 is the predicted class.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy}%")
```
阅读全文
相关推荐
![py](https://img-home.csdnimg.cn/images/20250102104920.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![-](https://img-home.csdnimg.cn/images/20241231045053.png)
![-](https://img-home.csdnimg.cn/images/20241231045053.png)
![-](https://img-home.csdnimg.cn/images/20241231044955.png)
![-](https://img-home.csdnimg.cn/images/20241231045053.png)
![-](https://img-home.csdnimg.cn/images/20241231044955.png)
![-](https://img-home.csdnimg.cn/images/20241231045053.png)
![-](https://img-home.csdnimg.cn/images/20241231045053.png)
![-](https://img-home.csdnimg.cn/images/20241231045053.png)
![-](https://img-home.csdnimg.cn/images/20241231045053.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)