ResNet架构的CNN识别图片代码
时间: 2024-05-01 20:16:01 浏览: 158
cnn图像识别代码
以下是使用PyTorch实现的、基于ResNet架构的CNN图像识别代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# ResNet-style network built from two-convolution residual blocks
class ResNet(nn.Module):
    """ResNet-style image classifier assembled from ``Block`` residual units.

    The network is a 7x7 stride-2 stem convolution with batch norm, ReLU and
    max pooling, followed by four residual stages (64, 128, 256 and 512
    channels), global average pooling and a linear classifier.

    Args:
        in_channels: Number of channels of the input images.
        num_classes: Size of the output logit vector.
        layers: Number of residual blocks in each of the four stages. The
            default ``(3, 4, 6, 3)`` reproduces the previously hard-coded
            layout. Every stage always contains at least one block.

    Raises:
        ValueError: If ``layers`` does not contain exactly four entries.
    """

    def __init__(self, in_channels, num_classes, layers=(3, 4, 6, 3)):
        super().__init__()
        if len(layers) != 4:
            raise ValueError(
                f'layers must contain exactly 4 stage depths, got {len(layers)}'
            )

        # Stem: reduces the spatial resolution by 4x before the residual stages.
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages: every stage after the first halves the resolution
        # (stride=2) and doubles the channel count.
        self.layer1 = self._make_layer(64, 64, layers[0])
        self.layer2 = self._make_layer(64, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(128, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(256, 512, layers[3], stride=2)

        # Head: global average pooling makes the classifier independent of
        # the spatial size that reaches it.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride=1):
        """Build one residual stage as an ``nn.Sequential`` of ``Block`` units.

        Only the first block receives ``stride`` and the channel change; the
        remaining ``num_blocks - 1`` blocks map ``out_channels`` to
        ``out_channels`` with stride 1.
        """
        blocks = [Block(in_channels, out_channels, stride)]
        blocks.extend(
            Block(out_channels, out_channels) for _ in range(num_blocks - 1)
        )
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Input batch, expected as ``(N, in_channels, H, W)``.

        Returns:
            Tensor of shape ``(N, num_classes)`` with unnormalized scores.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Collapse (N, 512, 1, 1) to (N, 512) for the linear classifier.
        x = torch.flatten(x, 1)
        return self.fc(x)
class Block(nn.Module):
    """Basic residual unit: two 3x3 conv + batch-norm layers plus a shortcut.

    The output is ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first convolution (and of the projection
            shortcut, when one is needed). Defaults to 1.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # The skip path must produce the same shape as the main path. When the
        # block changes resolution or width, project the input with a 1x1
        # convolution; otherwise an empty Sequential passes it through as-is.
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the residual block to ``x`` and return the activated sum."""
        identity = self.shortcut(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # In-place add targets the fresh bn2 output, never the caller's input.
        out += identity
        return self.relu(out)
# Load the CIFAR-10 dataset
# Training pipeline: random crop resized to 224x224 and random horizontal flip
# for augmentation, followed by tensor conversion and per-channel normalization.
# NOTE(review): the mean/std values look like the commonly used ImageNet
# statistics rather than CIFAR-10-specific ones — confirm this is intended.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Evaluation pipeline: deterministic resize + center crop, same normalization.
test_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
train_dataset = datasets.CIFAR10(root='dataset/', train=True, transform=train_transforms, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_dataset = datasets.CIFAR10(root='dataset/', train=False, transform=test_transforms, download=True)
# Evaluation does not benefit from shuffling; a fixed order keeps runs reproducible.
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)
# Set up the device, model, loss function and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNet(in_channels=3, num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

# Train the model
num_epochs = 10
# Make training mode explicit so batch norm uses per-batch statistics.
model.train()
for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        # Forward pass
        scores = model(data)
        loss = criterion(scores, targets)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log progress every 100 batches; the epoch is shown one-based so the
        # final epoch reads "num_epochs/num_epochs".
        if batch_idx % 100 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Batch [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}')
# Evaluate the model
def check_accuracy(loader, model):
    """Return the fraction of samples in ``loader`` that ``model`` gets right.

    The model is switched to evaluation mode for the duration of the call and
    its previous training/evaluation mode is restored afterwards, even if an
    exception is raised while iterating.

    Args:
        loader: Iterable yielding ``(data, targets)`` batches, where
            ``targets`` holds one integer class index per sample.
        model: Module mapping a ``data`` batch to per-class scores of shape
            ``(batch, num_classes)``.

    Returns:
        Accuracy as a Python float in ``[0, 1]``; ``0.0`` when the loader
        yields no samples.
    """
    # Run batches on whatever device the model lives on instead of relying on
    # a module-level global; a parameter-free model leaves the data untouched.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else None

    num_correct = 0
    num_samples = 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for data, targets in loader:
                if model_device is not None:
                    data = data.to(model_device)
                    targets = targets.to(model_device)
                scores = model(data)
                _, predictions = scores.max(1)
                num_correct += (predictions == targets).sum().item()
                num_samples += predictions.size(0)
    finally:
        # Restore the caller's mode rather than forcing training mode.
        model.train(was_training)

    if num_samples == 0:
        return 0.0
    return num_correct / num_samples
# Report the final accuracy on the training and test splits as percentages.
train_accuracy_pct = check_accuracy(train_loader, model) * 100
print(f'Train Accuracy: {train_accuracy_pct:.2f}%')
test_accuracy_pct = check_accuracy(test_loader, model) * 100
print(f'Test Accuracy: {test_accuracy_pct:.2f}%')
```
上述代码实现了对CIFAR-10数据集的图像分类任务:基于ResNet架构的卷积神经网络,使用SGD优化器和交叉熵损失函数训练10个epoch,并在全部训练结束后分别计算训练集和测试集上的准确率。
阅读全文