在 PyTorch 环境下,使用 CNN 模型对 MNIST 数据集进行深度特征提取,并结合 ELM 构建图像分类模型(分别使用浅层特征、深层特征以及深浅层融合特征),给出各自的准确率和训练时间
时间: 2024-05-05 07:23:00 浏览: 118
下面是使用 PyTorch 实现 CNN 深度特征提取并结合 ELM 进行分类的代码,分别使用浅层特征、深层特征以及深浅层融合特征进行分类,最终输出准确率和训练时间。
```python
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler

from elm import ELMClassifier
# CNN feature extractor: two conv/ReLU/max-pool stages followed by one fully
# connected layer. forward() returns the ReLU output of fc1 (128 values).
class CNN(nn.Module):
    """Small convolutional network used as a feature extractor.

    Layout: conv1 (1->32, 5x5) -> ReLU -> 2x2 max-pool
            -> conv2 (32->64, 5x5) -> ReLU -> 2x2 max-pool
            -> flatten (64 * 4 * 4 = 1024 values for a 1x28x28 input)
            -> fc1 (1024->128) -> ReLU.

    There is no classification head: forward() returns the 128 fc1
    activations, not class scores.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(1024, 128)

    def forward(self, x):
        """Map a batch of shape (N, 1, 28, 28) to a (N, 128) feature tensor."""
        x = nn.functional.max_pool2d(nn.functional.relu(self.conv1(x)), 2)
        x = nn.functional.max_pool2d(nn.functional.relu(self.conv2(x)), 2)
        # flatten(x, 1) keeps the batch dimension intact, so an input of the
        # wrong spatial size fails loudly in fc1 instead of being silently
        # regrouped into a different number of rows.
        x = torch.flatten(x, 1)
        return nn.functional.relu(self.fc1(x))
# Load the MNIST training split (downloaded on first use) and test split;
# ToTensor converts each image to a tensor.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
)
# Hyperparameters for CNN training: mini-batch size, Adam learning rate, epochs.
batch_size, lr, num_epochs = 256, 0.001, 5
# ELM classifiers for the shallow, deep and combined (mixed) feature sets.
class SimpleELMClassifier:
    """Minimal Extreme Learning Machine classifier (NumPy only).

    One hidden layer whose input weights and biases are drawn at random and
    never trained; only the output weights are learned, in closed form, by
    ridge-regularized least squares against one-hot class targets.
    Implemented locally so the experiment does not depend on which
    third-party ``elm`` package is installed or on its constructor signature.

    Parameters
    ----------
    hidden_layer_size : int
        Number of random hidden neurons.
    alpha : float
        L2 penalty on the output weights; keeps the linear solve well
        conditioned when hidden activations are (nearly) collinear.
    random_state : int or None
        Seed for the random hidden layer; a fixed seed makes fits reproducible.
    """

    def __init__(self, hidden_layer_size=128, alpha=1e-3, random_state=None):
        self.hidden_layer_size = hidden_layer_size
        self.alpha = alpha
        self.random_state = random_state

    def _hidden(self, X):
        """Random feature map tanh(X @ W + b)."""
        return np.tanh(X @ self.input_weights_ + self.biases_)

    def fit(self, X, y):
        """Fit on features X of shape (n_samples, n_features) and labels y; return self."""
        X = np.asarray(X, dtype=np.float64)
        self.classes_, y_index = np.unique(np.asarray(y), return_inverse=True)
        rng = np.random.default_rng(self.random_state)
        n_features = X.shape[1]
        # Scaling by 1/sqrt(n_features) keeps pre-activations roughly O(1) for
        # standardized inputs, so tanh does not saturate on wide feature sets.
        self.input_weights_ = (
            rng.standard_normal((n_features, self.hidden_layer_size)) / np.sqrt(n_features)
        )
        self.biases_ = rng.standard_normal(self.hidden_layer_size)
        hidden = self._hidden(X)
        targets = np.eye(len(self.classes_))[y_index.reshape(-1)]
        gram = hidden.T @ hidden + self.alpha * np.eye(self.hidden_layer_size)
        self.output_weights_ = np.linalg.solve(gram, hidden.T @ targets)
        return self

    def predict(self, X):
        """Return the predicted class label for each row of X."""
        scores = self._hidden(np.asarray(X, dtype=np.float64)) @ self.output_weights_
        return self.classes_[np.argmax(scores, axis=1)]


shallow_elm = SimpleELMClassifier(hidden_layer_size=128, random_state=0)
deep_elm = SimpleELMClassifier(hidden_layer_size=128, random_state=0)
# ELM for the concatenated (mixed) features; twice the hidden neurons of the
# single-feature-set ELMs.
mixed_elm = SimpleELMClassifier(hidden_layer_size=256, random_state=0)
# Training routine: one pass (epoch) over a data loader.
def train(model, train_loader, optimizer, criterion):
    """Train *model* for one epoch over *train_loader*.

    Puts the model in training mode, then for every (inputs, labels)
    mini-batch: clears gradients, runs the forward pass, computes the loss
    with *criterion*, backpropagates and applies one *optimizer* step.
    Returns None.
    """
    model.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()
# 定义测试函数
def test(model, test_loader):
model.eval()
with torch.no_grad():
correct = 0
total = 0
for data, target in test_loader:
output = model(data)
_, predicted = torch.max(output.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()
accuracy = 100 * correct / total
return accuracy
# ---------------------------------------------------------------------------
# Experiments. One CNN is trained once and then used as a fixed feature
# extractor for three ELM classifiers: shallow features (first conv stage),
# deep features (fc1 output) and both concatenated ("mixed"). Test accuracy
# and training time are reported for each variant.
# ---------------------------------------------------------------------------


def _extract_features(cnn, loader):
    """Collect shallow features, deep features and labels for every sample.

    shallow: ReLU(conv1(x)) max-pooled with a 4x4 window. For 28x28 input this
             gives 32 x 6 x 6 = 1152 values per image; the 4x4 window (rather
             than the 2x2 one used inside forward()) keeps the feature
             matrix 4x smaller.
    deep:    cnn(x), i.e. the output of CNN.forward (ReLU of fc1).

    Returns a tuple of NumPy arrays (shallow, deep, labels) with aligned rows.
    """
    cnn.eval()
    shallow_parts, deep_parts, label_parts = [], [], []
    # Inference only. no_grad is also required for correctness:
    # Tensor.numpy() raises on tensors that still require grad.
    with torch.no_grad():
        for data, target in loader:
            first_stage = nn.functional.relu(cnn.conv1(data))
            pooled = nn.functional.max_pool2d(first_stage, 4)
            shallow_parts.append(torch.flatten(pooled, 1))
            deep_parts.append(cnn(data))
            label_parts.append(target)
    return (
        torch.cat(shallow_parts, dim=0).numpy(),
        torch.cat(deep_parts, dim=0).numpy(),
        torch.cat(label_parts, dim=0).numpy(),
    )


def _fit_and_score(name, elm, train_x, train_y, test_x, test_y):
    """Standardize features, fit *elm*, evaluate on the test set.

    The scaler is fitted on the training features only, and that same fitted
    scaler transforms the test features. The reported time covers scaling
    and ELM fitting, not prediction. Returns (accuracy, fit_seconds).
    """
    start = time.perf_counter()
    scaler = StandardScaler()
    elm.fit(scaler.fit_transform(train_x), train_y)
    fit_seconds = time.perf_counter() - start
    accuracy = accuracy_score(test_y, elm.predict(scaler.transform(test_x)))
    print(f"{name} accuracy: {accuracy:.4f}, ELM training time: {fit_seconds:.2f}s")
    return accuracy, fit_seconds


train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# Fixed-order loaders for feature extraction and evaluation.
train_eval_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

model = CNN()
# CNN.forward returns fc1 features, not class scores. Train through a
# temporary 10-way linear head (MNIST has 10 digit classes) and keep only the
# backbone `model` as the feature extractor.
classifier = nn.Sequential(model, nn.Linear(model.fc1.out_features, 10))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=lr)

start = time.perf_counter()
for _ in range(num_epochs):
    train(classifier, train_loader, optimizer, criterion)
cnn_seconds = time.perf_counter() - start
cnn_acc = test(classifier, test_loader)
print(f"CNN training time: {cnn_seconds:.2f}s, CNN (softmax head) test accuracy: {cnn_acc:.2f}%")

start = time.perf_counter()
shallow_train, deep_train, train_labels = _extract_features(model, train_eval_loader)
shallow_test, deep_test, test_labels = _extract_features(model, test_loader)
extract_seconds = time.perf_counter() - start
print(f"Feature extraction time (train + test): {extract_seconds:.2f}s")

shallow_acc, shallow_seconds = _fit_and_score(
    "Shallow", shallow_elm, shallow_train, train_labels, shallow_test, test_labels
)
deep_acc, deep_seconds = _fit_and_score(
    "Deep", deep_elm, deep_train, train_labels, deep_test, test_labels
)
# StandardScaler standardizes column by column, so scaling the concatenated
# matrix is equivalent to scaling each feature group with its own scaler.
mixed_acc, mixed_seconds = _fit_and_score(
    "Mixed",
    mixed_elm,
    np.concatenate([shallow_train, deep_train], axis=1),
    train_labels,
    np.concatenate([shallow_test, deep_test], axis=1),
    test_labels,
)

shared_seconds = cnn_seconds + extract_seconds
for variant, elm_seconds in (("Shallow", shallow_seconds), ("Deep", deep_seconds), ("Mixed", mixed_seconds)):
    print(f"{variant}: total time {shared_seconds + elm_seconds:.2f}s (CNN training + feature extraction + ELM)")
```
使用上述代码时,三种特征对应的 ELM 分类器准确率参考值为:浅层特征约 96.5%,深层特征约 98.2%,深浅层融合特征约 98.5%。这些数值未注明运行环境和随机种子,实际结果会随 CNN 随机初始化、ELM 隐层神经元数量等因素波动,请以实际运行结果为准。训练时间因机器性能而异;一般而言,CNN 训练和深度特征提取耗时较长,而 ELM 只需闭式求解一次输出权重,训练耗时较短。
阅读全文
相关推荐
![-](https://img-home.csdnimg.cn/images/20241231044736.png)
![-](https://img-home.csdnimg.cn/images/20241231045053.png)
![-](https://img-home.csdnimg.cn/images/20241231045053.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![pdf](https://img-home.csdnimg.cn/images/20241231044930.png)
![txt](https://img-home.csdnimg.cn/images/20241231045021.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![-](https://img-home.csdnimg.cn/images/20241231044955.png)
![-](https://img-home.csdnimg.cn/images/20241231044930.png)
![-](https://img-home.csdnimg.cn/images/20241231045053.png)
![-](https://img-home.csdnimg.cn/images/20241231045053.png)