Classifying the Iris dataset with a Transformer model
Date: 2023-08-23 10:04:36 Views: 91
A Transformer model can be used to classify the Iris dataset. Here is some sample code:
```
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load the Iris data (150 samples, 4 features, 3 classes)
iris = load_iris()
X, y = iris.data, iris.target

# Standardize the features
scaler = StandardScaler()
X = scaler.fit_transform(X)

# Split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Hyperparameters
num_classes = 3
num_epochs = 100
batch_size = 16
input_dim = X.shape[1]  # 4 features, each treated as one token
d_model = 32            # token embedding width; must be divisible by nhead
nhead = 8
hidden_dim = 64         # feed-forward width inside each encoder layer
num_layers = 6
dropout = 0.5

# Define the Transformer model
class TransformerModel(nn.Module):
    def __init__(self, input_dim, d_model, nhead, hidden_dim, num_layers, num_classes, dropout):
        super().__init__()
        # Project every scalar feature to a d_model-dimensional token
        self.embedding = nn.Linear(1, d_model)
        # Learned positional embedding so the model can tell the features apart
        self.pos_embedding = nn.Parameter(torch.zeros(1, input_dim, d_model))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead=nhead, dim_feedforward=hidden_dim,
            dropout=dropout, batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        # x: (batch, input_dim) -> (batch, input_dim, 1) -> (batch, input_dim, d_model)
        x = self.embedding(x.unsqueeze(-1)) + self.pos_embedding
        x = self.transformer_encoder(x)
        # Mean-pool over the feature tokens, then classify
        x = x.mean(dim=1)
        return self.fc(x)

# Wrap the arrays in datasets and data loaders
train_dataset = TensorDataset(torch.tensor(X_train, dtype=torch.float32), torch.tensor(y_train, dtype=torch.long))
test_dataset = TensorDataset(torch.tensor(X_test, dtype=torch.float32), torch.tensor(y_test, dtype=torch.long))
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Define the model, loss function, and optimizer
model = TransformerModel(input_dim, d_model, nhead, hidden_dim, num_layers, num_classes, dropout)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train the model
for epoch in range(num_epochs):
    model.train()
    for i, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)  # logits of shape (batch, num_classes)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        if (i + 1) % 5 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, len(train_loader), loss.item()))

# Evaluate the model
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
# Report the real number of test samples, not len(test_loader) * batch_size
print('Test accuracy of the model on the {} test samples: {:.2f} %'.format(total, 100 * correct / total))
```
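In this code, we first load the Iris data, standardize it, and split it into training and test sets. We then define a Transformer model and use PyTorch's DataLoader to batch the data. The model is trained with cross-entropy loss and the Adam optimizer. At test time, we use the model to predict the labels of the test set and compute the accuracy.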
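The listing fits `StandardScaler` on the full dataset before splitting, so the test rows contribute to the mean and variance used for scaling. On a dataset this small the effect is usually minor, but a common alternative is to split first and fit the scaler on the training rows only. Here is a minimal sketch of that ordering, reusing the same scikit-learn calls. The names `X_train_scaled` and `X_test_scaled` are illustrative and do not appear in the original.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

iris = load_iris()
X, y = iris.data, iris.target

# Split first, so the test rows stay unseen during preprocessing
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Fit the scaler on the training rows only ...
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
# ... and reuse the training mean and standard deviation for the test rows
X_test_scaled = scaler.transform(X_test)

print(X_train_scaled.shape, X_test_scaled.shape)
```

The scaled arrays can then be wrapped in tensors in place of `X_train` and `X_test`.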
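Note that this example uses a relatively small Transformer with only 6 encoder layers, because the Iris dataset is very small (150 samples with 4 features each). Real applications may need a larger model to achieve better performance.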
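To make "small" concrete, one can count the trainable parameters of the encoder at different depths. The sketch below is illustrative only: the helper `count_encoder_params` is hypothetical, `d_model=32` is an assumed embedding width, and `dim_feedforward=64` mirrors `hidden_dim` from the listing.

```python
from torch import nn

def count_encoder_params(num_layers, d_model=32, nhead=8, dim_feedforward=64):
    """Count trainable parameters of an nn.TransformerEncoder (hypothetical helper)."""
    layer = nn.TransformerEncoderLayer(
        d_model, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True
    )
    # nn.TransformerEncoder stacks num_layers independent copies of the layer
    encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
    return sum(p.numel() for p in encoder.parameters() if p.requires_grad)

for n in (1, 2, 6, 12):
    print(f"{n:2d} layers: {count_encoder_params(n):,} parameters")
```

Because every encoder layer has the same shape, the count grows linearly with `num_layers`. Widening `d_model` increases the size of each layer instead.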