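Implementing Handwritten Digit Recognition with AlexNet
Date: 2024-05-14 16:12:03
AlexNet is a deep convolutional neural network proposed by Alex Krizhevsky, Ilya Sutskever, and Geoffrey Hinton. It won the ImageNet Large Scale Visual Recognition Challenge (ILSVRC) in 2012. It was designed for image classification, and its convolutional layers have since also served as feature extractors for tasks such as object detection and semantic segmentation. In this article, we show how to use AlexNet to recognize handwritten digits.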
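1. Dataset Preparation
We train the network on the MNIST handwritten digit dataset. MNIST consists of 28×28-pixel grayscale images of handwritten digits: 60,000 training images and 10,000 test images. Each image comes with a label giving the digit (0 to 9) it shows.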
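These dataset sizes also determine how many mini-batches one epoch contains. The plain-Python sketch below assumes the batch size of 128 used by this article's DataLoaders and PyTorch's default `drop_last=False`. The function name `batch_layout` is hypothetical. It shows that the last batch of each epoch is smaller than the others, which is why the training loop weights each batch's loss by `data.size(0)`.

```python
import math
from typing import Tuple

# Dataset sizes stated above; batch size used by this article's DataLoaders.
TRAIN_SIZE = 60000
TEST_SIZE = 10000
BATCH_SIZE = 128


def batch_layout(num_samples: int, batch_size: int) -> Tuple[int, int]:
    """Return (number of batches, size of the last batch) for drop_last=False."""
    num_batches = math.ceil(num_samples / batch_size)
    # Every batch except the last is full; the last one holds the remainder.
    last_batch = num_samples - (num_batches - 1) * batch_size
    return num_batches, last_batch


if __name__ == "__main__":
    print("train:", batch_layout(TRAIN_SIZE, BATCH_SIZE))
    print("test:", batch_layout(TEST_SIZE, BATCH_SIZE))
```

For the training set this gives 469 batches, the last holding 96 images. For the test set it gives 79 batches, the last holding 16 images.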
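2. Network Architecture
AlexNet consists of 5 convolutional layers followed by 3 fully connected layers. Not every convolutional layer is followed by pooling: max pooling comes after the 1st, 2nd, and 5th convolutional layers. Every convolutional layer and both hidden fully connected layers use a ReLU activation. The last fully connected layer outputs one raw score (logit) per class, i.e. 10 scores for the digits 0 to 9. Applying softmax to these scores gives a probability distribution over the digits. The version used here follows the layer sizes of torchvision's AlexNet, with the first convolution changed to accept a single grayscale channel.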
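The spatial size after each convolution or pooling layer follows the standard formula `out = floor((n + 2*padding - kernel_size) / stride) + 1`. This assumes dilation 1 and `ceil_mode=False`, which are the PyTorch defaults. The plain-Python sketch below passes the kernel, stride, and padding settings of the feature extractor from the PyTorch code in Section 3 through this formula. `feature_sizes` is a hypothetical helper, not a PyTorch API.

```python
# (layer name, kernel_size, stride, padding) of each Conv2d / MaxPool2d in
# the feature extractor, in order; ReLU layers do not change the size.
FEATURE_LAYERS = [
    ("conv1", 11, 4, 2),
    ("pool1", 3, 2, 0),
    ("conv2", 5, 1, 2),
    ("pool2", 3, 2, 0),
    ("conv3", 3, 1, 1),
    ("conv4", 3, 1, 1),
    ("conv5", 3, 1, 1),
    ("pool3", 3, 2, 0),
]


def out_size(n, kernel_size, stride, padding):
    """Output height/width of Conv2d or MaxPool2d (dilation=1, ceil_mode=False)."""
    return (n + 2 * padding - kernel_size) // stride + 1


def feature_sizes(input_size):
    """Trace a square input of the given size through FEATURE_LAYERS.

    Returns the output size after each layer, stopping at the first layer
    whose output would be smaller than 1 pixel (PyTorch rejects such a layer).
    """
    sizes = []
    n = input_size
    for _name, k, s, p in FEATURE_LAYERS:
        n = out_size(n, k, s, p)
        sizes.append(n)
        if n < 1:
            break
    return sizes


if __name__ == "__main__":
    for size in (224, 28):
        print(size, feature_sizes(size))
```

For a 224×224 input the sizes are 55, 27, 27, 13, 13, 13, 13, 6, so the final feature map is already 6×6. For a raw 28×28 MNIST image the trace stops at the second pooling layer with size 0. PyTorch raises an error at that point, because the 2×2 map is smaller than the 3×3 pooling window.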
3. Training the Network
We implement AlexNet with the PyTorch framework. First, we define the network structure:
```python
import torch
import torch.nn as nn


class AlexNet(nn.Module):
    """AlexNet (torchvision layer sizes) adapted to single-channel MNIST input."""

    def __init__(self, num_classes=10):
        super(AlexNet, self).__init__()
        # Feature extractor: 5 conv layers, max pooling after conv1, conv2 and conv5
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2),  # 1 input channel (grayscale)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Forces a 6x6 spatial size so the classifier input is always 256*6*6
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        # Classifier: 3 fully connected layers, dropout before the first two
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),  # raw class scores (logits)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # (N, 256, 6, 6) -> (N, 256*6*6)
        x = self.classifier(x)
        return x
```
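To get a feel for the model's size, the parameter count can be worked out by hand. A `Conv2d` layer has `out_channels * in_channels * k * k` weights plus `out_channels` biases, and a `Linear` layer has `out_features * in_features` weights plus `out_features` biases. The plain-Python sketch below applies this to the layer shapes in the class above (1 input channel, `num_classes=10`). The result should match `sum(p.numel() for p in model.parameters())` on the PyTorch model.

```python
# Layer shapes copied from the AlexNet class (1 input channel, num_classes=10).
CONV_LAYERS = [  # (in_channels, out_channels, kernel_size)
    (1, 64, 11),
    (64, 192, 5),
    (192, 384, 3),
    (384, 256, 3),
    (256, 256, 3),
]
LINEAR_LAYERS = [  # (in_features, out_features)
    (256 * 6 * 6, 4096),
    (4096, 4096),
    (4096, 10),
]


def conv_params(c_in, c_out, k):
    # weight tensor c_out x c_in x k x k, plus one bias per output channel
    return c_out * c_in * k * k + c_out


def linear_params(f_in, f_out):
    # weight matrix f_out x f_in, plus one bias per output feature
    return f_out * f_in + f_out


conv_total = sum(conv_params(*layer) for layer in CONV_LAYERS)
fc_total = sum(linear_params(*layer) for layer in LINEAR_LAYERS)
total = conv_total + fc_total

if __name__ == "__main__":
    print(f"conv: {conv_total:,}  fc: {fc_total:,}  total: {total:,}")
    print(f"share of fully connected layers: {fc_total / total:.1%}")
```

The total comes to 57,029,322 parameters. The three fully connected layers hold 54,575,114 of them (about 95.7%), mostly in the first `Linear(256 * 6 * 6, 4096)` layer, which is also where the model's `Dropout` layers sit.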
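Next, we define the data loaders, the loss function, and the optimizer. One detail matters here: this network was designed for inputs of about 224×224 pixels. A 28×28 MNIST image shrinks to 2×2 after the first convolution and pooling layers. That is smaller than the 3×3 window of the second pooling layer, so the forward pass fails with an error. The images therefore have to be upsampled (for example, to 224×224) before they enter the network: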
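```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Upsample the 28x28 digits to 224x224 so every pooling layer still has a large
# enough input, then normalize with the commonly used MNIST mean/std.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AlexNet(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()  # applies log-softmax to the logits internally
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
```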
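The `optim.SGD` call combines three ingredients: a learning rate, momentum, and L2 weight decay. The sketch below spells out the update that `torch.optim.SGD` applies to each parameter, reduced to a single float in plain Python. It assumes PyTorch's defaults `dampening=0` and `nesterov=False`. The class `ScalarSGD` is hypothetical and only illustrates the arithmetic.

```python
class ScalarSGD:
    """Illustrative SGD-with-momentum update for one scalar parameter
    (assumes dampening=0 and nesterov=False)."""

    def __init__(self, lr=0.01, momentum=0.9, weight_decay=5e-4):
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.buf = None  # momentum buffer, created on the first step

    def step(self, param, grad):
        # L2 weight decay is folded into the gradient
        g = grad + self.weight_decay * param
        # The first step initializes the buffer with the gradient itself
        if self.buf is None:
            self.buf = g
        else:
            self.buf = self.momentum * self.buf + g
        # Move against the accumulated direction
        return param - self.lr * self.buf


if __name__ == "__main__":
    # Minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3)
    opt = ScalarSGD()
    w = 0.0
    for _ in range(500):
        w = opt.step(w, 2 * (w - 3))
    print(w)
```

Because weight decay adds `weight_decay * w` to the gradient, the iterate settles at 6 / 2.0005 ≈ 2.99925 rather than exactly 3. This small pull toward zero is the regularizing effect that `weight_decay=5e-4` has on the network's weights.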
Finally, we train the network for 10 epochs and evaluate it on the test set after each epoch:
```python
for epoch in range(10):
    # ---- training phase ----
    train_loss = 0
    train_acc = 0
    model.train()  # dropout active
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight it by the batch size
        train_loss += loss.item() * data.size(0)
        pred = output.argmax(dim=1, keepdim=True)
        train_acc += pred.eq(target.view_as(pred)).sum().item()
    train_loss /= len(train_loader.dataset)
    train_acc /= len(train_loader.dataset)

    # ---- evaluation on the test set after each epoch ----
    test_loss = 0
    test_acc = 0
    model.eval()  # dropout disabled
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            test_loss += loss.item() * data.size(0)
            pred = output.argmax(dim=1, keepdim=True)
            test_acc += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    test_acc /= len(test_loader.dataset)

    print('Epoch: {} \tTraining Loss: {:.6f} \tTraining Accuracy: {:.6f} \tTesting Loss: {:.6f} \tTesting Accuracy: {:.6f}'.format(
        epoch + 1, train_loss, train_acc, test_loss, test_acc))
```
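Two details of this loop are easy to miss. `nn.CrossEntropyLoss` takes the raw logits and applies log-softmax itself, and by default it returns the mean loss of the batch. Multiplying by `data.size(0)` and dividing by the dataset size therefore gives the exact per-image average, even though the last batch is smaller than 128. The plain-Python sketch below reproduces this bookkeeping for a hypothetical list of batches. `cross_entropy` and `epoch_metrics` are illustrative helpers, not PyTorch functions.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one sample, -log(softmax(logits)[target]),
    computed in the numerically stable log-sum-exp form."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def epoch_metrics(batches):
    """batches: list of (list_of_logit_vectors, list_of_targets).

    Mirrors the training loop: each batch's mean loss is multiplied by the
    batch size, and the sums are divided by the total number of samples.
    """
    total_loss = 0.0
    correct = 0
    num_samples = 0
    for logits_batch, targets in batches:
        losses = [cross_entropy(z, t) for z, t in zip(logits_batch, targets)]
        batch_mean = sum(losses) / len(losses)   # like criterion(output, target)
        total_loss += batch_mean * len(targets)  # like loss.item() * data.size(0)
        # Index of the largest logit, like output.argmax(dim=1)
        preds = [max(range(len(z)), key=z.__getitem__) for z in logits_batch]
        correct += sum(p == t for p, t in zip(preds, targets))
        num_samples += len(targets)
    return total_loss / num_samples, correct / num_samples
```

With a full batch of two images followed by a last batch of one, the weighted sum still divides out to the plain average over all three images. Averaging the two batch means directly would give the single image in the last batch as much weight as the other two together.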
4. Testing the Network
After training, we evaluate the final model on the test set:
```python
model.eval()
with torch.no_grad():
    test_loss = 0
    test_acc = 0
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        test_loss += criterion(output, target).item() * data.size(0)
        pred = output.argmax(dim=1, keepdim=True)
        test_acc += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    test_acc /= len(test_loader.dataset)
print('Test Loss: {:.6f} \tTest Accuracy: {:.6f}'.format(test_loss, test_acc))
```
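The call to `model.eval()` matters because the classifier contains two `nn.Dropout()` layers (default `p=0.5`). In training mode, dropout zeroes each activation with probability `p` and scales the survivors by `1/(1-p)`, so the expected value of each activation is unchanged. In evaluation mode it passes activations through untouched, which makes test predictions deterministic. `torch.no_grad()` is a separate switch that only stops gradient tracking. The plain-Python sketch below mimics this dropout behavior on a list of floats. `inverted_dropout` is a hypothetical helper, not the PyTorch implementation.

```python
import random


def inverted_dropout(values, p=0.5, training=True, rng=None):
    """Illustrative dropout: in training mode zero each element with
    probability p and scale survivors by 1/(1-p); in eval mode return
    the input unchanged. This sketch only supports 0 <= p < 1."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    if not training or p == 0.0:
        return list(values)
    if rng is None:
        rng = random
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * scale for v in values]


if __name__ == "__main__":
    rng = random.Random(0)
    out = inverted_dropout([1.0] * 10000, rng=rng)
    # Roughly half the entries are 0.0 and the rest 2.0, so the mean stays near 1.0
    print(sum(out) / len(out))
    print(inverted_dropout([0.3, -1.2], training=False))
```

Scaling during training rather than at test time is what lets evaluation skip dropout entirely without any rescaling of the weights.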
With the steps above, AlexNet can be trained to recognize handwritten digits.