基于transformer的鸢尾花数据分类的Python代码
时间: 2024-05-03 08:21:19 浏览: 9
以下是基于transformer的鸢尾花数据分类的Python代码:
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import numpy as np
# Load the iris dataset. X and y keep the raw (unscaled) features and labels.
iris = load_iris()
X = iris.data
y = iris.target
# Split BEFORE scaling so the held-out rows never influence the scaler's
# statistics (fitting on the full matrix leaks test information into training).
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Fit the scaler on the training split only, then apply the same transform
# to both splits.
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)
# Dataset wrapper pairing each feature row with its class label
class IrisDataset(Dataset):
    """Map-style dataset yielding ``(features, label)`` tensor pairs.

    Args:
        X: Indexable collection of feature rows (``X[idx]`` is one sample).
        y: Indexable collection of integer class labels, same length as ``X``.

    Raises:
        ValueError: If ``X`` and ``y`` do not contain the same number of items.
    """

    def __init__(self, X, y):
        # Fail early: a length mismatch would otherwise surface as an
        # IndexError in the middle of an epoch, or silently drop labels.
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length, got {len(X)} and {len(y)}"
            )
        self.X = X
        self.y = y

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a float32 feature tensor and an int64 label tensor."""
        features = torch.tensor(self.X[idx], dtype=torch.float32)
        label = torch.tensor(self.y[idx], dtype=torch.long)
        return features, label
# Transformer-encoder classifier
class TransformerModel(nn.Module):
    """Transformer encoder followed by a linear classification head.

    Args:
        d_model: Width of every token vector fed to the encoder.
        nhead: Number of attention heads in each encoder layer.
        num_layers: Number of stacked encoder layers.
        num_classes: Number of output logits.
    """

    def __init__(self, d_model, nhead, num_layers, num_classes):
        super().__init__()
        # The encoder layer is left at its default sequence-first layout,
        # which is why forward() permutes the input.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead),
            num_layers=num_layers,
        )
        self.linear = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            x: Either ``(batch, seq_len, d_model)`` or a plain tabular batch
                ``(batch, d_model)``, which is treated as a sequence of
                length one.

        Raises:
            ValueError: If ``x`` is neither 2-D nor 3-D.
        """
        if x.dim() == 2:
            # Tabular rows carry no sequence axis; add one so the permute
            # below does not fail on a 2-D tensor.
            x = x.unsqueeze(1)
        elif x.dim() != 3:
            raise ValueError(
                f"expected a 2-D or 3-D input tensor, got {x.dim()} dimension(s)"
            )
        x = x.permute(1, 0, 2)  # (batch, seq, d_model) -> (seq, batch, d_model)
        x = self.transformer_encoder(x)
        x = x[-1, :, :]  # keep only the last sequence position
        x = self.linear(x)
        return x
# Hyperparameters: model size first, then optimisation settings
d_model, nhead, num_layers, num_classes = 4, 2, 2, 3
batch_size, lr, num_epochs = 16, 0.001, 50

# Data loaders; only the training loader reshuffles its samples
train_dataset = IrisDataset(X=X_train, y=y_train)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = IrisDataset(X=X_test, y=y_test)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size)

# Model, loss function and optimizer
model = TransformerModel(
    d_model=d_model,
    nhead=nhead,
    num_layers=num_layers,
    num_classes=num_classes,
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
# Train the model and report accuracy on both splits after every epoch
def _accuracy(net, loader):
    """Return the fraction of samples in ``loader`` that ``net`` labels correctly."""
    correct = 0
    total = 0
    for inputs, labels in loader:
        predicted = net(inputs).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    return correct / total


for epoch in range(num_epochs):
    # Re-enable dropout for the optimisation pass; the evaluation below
    # switches it off.
    model.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Evaluate with dropout disabled and without building autograd graphs,
    # so the reported accuracy is not perturbed by random dropout masks.
    model.eval()
    with torch.no_grad():
        train_acc = _accuracy(model, train_loader)
        test_acc = _accuracy(model, test_loader)
    print(f'Epoch {epoch + 1}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
```
这段代码使用了PyTorch中的transformer模块来实现鸢尾花数据的分类。首先,加载鸢尾花数据集并进行标准化处理,然后划分训练集和测试集。接着定义了一个IrisDataset类来表示鸢尾花数据集,并且定义了一个TransformerModel类来表示transformer模型。在训练过程中,使用了Adam优化器和交叉熵损失函数。最后,在每个epoch结束时,在训练集和测试集上进行测试并输出准确率。