基于transformer对鸢尾花数据分类的代码
时间: 2023-11-04 22:04:12 浏览: 102
以下是基于 Transformer 对鸢尾花数据进行分类的代码示例:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
# Load the iris dataset
iris = load_iris()
# Split BEFORE scaling: fitting the scaler on the full dataset would let
# test-set statistics (mean/variance) leak into the training features.
# `stratify` keeps the class proportions equal in both splits and
# `random_state` makes the split reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(
    iris.data,
    iris.target,
    test_size=0.2,
    stratify=iris.target,
    random_state=42,
)
# Standardize: fit on the training split only, then reuse the same
# transform for the test split.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
# Kept for backward compatibility with code that reads `features`:
# the full dataset, scaled with the training-set statistics.
features = scaler.transform(iris.data)
# Hyperparameters
# -- training --
EPOCHS = 100      # number of passes over the training loader
BATCH_SIZE = 16   # batch size given to both data loaders
LR = 1e-3         # learning rate passed to the Adam optimizer
# -- model --
NUM_CLASSES = 3   # output width of the final linear layer
EMBED_DIM = 16    # width of the embedding / encoder layers
NUM_HEADS = 4     # attention heads per encoder layer
HIDDEN_SIZE = 64  # third positional argument of each encoder layer
NUM_LAYERS = 2    # number of stacked encoder layers
# Custom dataset wrapper
class IrisDataset(Dataset):
    """Pair a feature container with a label container for use with a DataLoader.

    Each item is returned as a ``(features, label)`` tuple of tensors:
    the features as ``float32`` and the label as ``long``.
    """

    def __init__(self, data, target):
        # Store the containers as-is; conversion to tensors happens per item.
        self.data = data
        self.target = target

    def __len__(self):
        # The number of samples is taken from the label container.
        return len(self.target)

    def __getitem__(self, index):
        sample, label = self.data[index], self.target[index]
        return (
            torch.tensor(sample, dtype=torch.float32),
            torch.tensor(label, dtype=torch.long),
        )
# Transformer-based classifier
class TransformerModel(nn.Module):
    """Classify feature vectors with a Transformer encoder.

    Args:
        num_classes: Number of output logits.
        embed_dim: Width of the embedding and of the encoder layers.
        num_heads: Number of attention heads per encoder layer.
        hidden_size: Passed as the third positional argument of
            ``nn.TransformerEncoderLayer``.
        num_layers: Number of stacked encoder layers.
        input_dim: Number of input features per token. Defaults to 4,
            the value that was previously hard-coded.
    """

    def __init__(self, num_classes, embed_dim, num_heads, hidden_size, num_layers, input_dim=4):
        super().__init__()
        self.embedding = nn.Linear(input_dim, embed_dim)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(embed_dim, num_heads, hidden_size),
            num_layers=num_layers,
        )
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        Accepts either ``(batch, input_dim)`` or ``(batch, seq_len, input_dim)``.
        """
        # A flat (batch, input_dim) batch has no sequence axis, so the
        # 3-axis permute below would raise. Treat each sample as a
        # sequence of length 1 instead.
        if x.dim() == 2:
            x = x.unsqueeze(1)
        x = self.embedding(x)
        # The encoder is built without batch_first, so move the sequence
        # axis to the front and restore batch-first order afterwards.
        x = x.permute(1, 0, 2)
        x = self.transformer(x)
        x = x.permute(1, 0, 2)
        # Classify from the last position of the sequence.
        return self.fc(x[:, -1, :])
# Wrap both splits in datasets, then batch them with data loaders
train_dataset = IrisDataset(data=X_train, target=y_train)
test_dataset = IrisDataset(data=X_test, target=y_test)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,  # only the training loader is shuffled
)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
# Build the model and the optimizer that updates its parameters
model = TransformerModel(
    num_classes=NUM_CLASSES,
    embed_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
)
optimizer = optim.Adam(model.parameters(), lr=LR)
# Train the model
for epoch in range(EPOCHS):
    # Set training mode explicitly every epoch: if this section is
    # re-run after `model.eval()` has been called (e.g. in a notebook),
    # the model would otherwise keep training in evaluation mode.
    model.train()
    train_loss = 0
    train_acc = 0
    for x_batch, y_batch in train_loader:
        optimizer.zero_grad()  # clear gradients from the previous step
        y_pred = model(x_batch)
        loss = nn.functional.cross_entropy(y_pred, y_batch)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        # Count correct predictions in this batch.
        train_acc += (y_pred.argmax(1) == y_batch).sum().item()
    # Loss is averaged over batches, accuracy over samples.
    train_loss /= len(train_loader)
    train_acc /= len(train_dataset)
    print(f"Epoch {epoch+1}/{EPOCHS}, train_loss: {train_loss:.4f}, train_acc: {train_acc:.4f}")
# Evaluate the model on the test loader
model.eval()  # switch the model to evaluation mode
test_acc = 0
with torch.no_grad():  # gradients are not needed for evaluation
    for x_batch, y_batch in test_loader:
        y_pred = model(x_batch)
        # Add the number of correct predictions in this batch.
        test_acc += y_pred.argmax(dim=1).eq(y_batch).sum().item()
# Convert the correct-prediction count into a fraction of the test set.
test_acc = test_acc / len(test_dataset)
print(f"Test_acc: {test_acc:.4f}")
```
该代码使用 sklearn 自带的鸢尾花数据集,并用 Transformer 模型对其进行分类。首先对数据进行标准化,并划分训练集和测试集;然后定义自定义数据集类和 Transformer 模型;接着创建数据加载器、模型和优化器,依次进行训练和测试;最终输出测试集上的准确率。
阅读全文
相关推荐
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)