写出基于densenet网络的分类模型代码和训练代码
时间: 2023-05-30 21:01:30 浏览: 160
基于DenseNet的分类模型代码:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class DenseBlock(nn.Module):
    """A single bottleneck dense layer: BN-ReLU-Conv1x1 -> BN-ReLU-Conv3x3.

    Produces ``growth_rate`` new feature maps and concatenates them onto
    the input along the channel axis, so the output carries
    ``in_channels + growth_rate`` channels.
    """

    def __init__(self, in_channels, growth_rate):
        super(DenseBlock, self).__init__()
        # The 1x1 conv widens to 4*k channels before the 3x3 (DenseNet-B bottleneck).
        bottleneck_channels = 4 * growth_rate
        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(bottleneck_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck_channels, growth_rate, kernel_size=3, padding=1, bias=False),
        )

    def forward(self, x):
        new_features = self.conv_block(x)
        # Dense connectivity: keep the input and append the new maps.
        return torch.cat([x, new_features], dim=1)
class TransitionBlock(nn.Module):
    """Transition between dense blocks: BN-ReLU-Conv1x1, then 2x2 avg-pool.

    Compresses the channel count from ``in_channels`` to ``out_channels``
    and halves the spatial resolution.
    """

    def __init__(self, in_channels, out_channels):
        super(TransitionBlock, self).__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        # BN -> ReLU -> 1x1 conv -> 2x2 average pooling, same order as init.
        return self.avg_pool(self.conv(F.relu(self.bn(x))))
class DenseNet(nn.Module):
    """CIFAR-style DenseNet classifier.

    A 3x3 stem conv is followed by ``len(block_config)`` dense blocks,
    each containing ``block_config[i]`` dense layers of ``growth_rate``
    channels; a transition block between dense blocks halves both the
    channel count and spatial size. Global average pooling feeds a
    single linear classifier.
    """

    def __init__(self, growth_rate=12, block_config=(16, 16, 16), num_classes=10):
        super(DenseNet, self).__init__()
        channels = 2 * growth_rate
        # Stem: plain 3x3 conv (no stride/pooling), suited to 32x32 inputs.
        self.features = nn.Sequential(
            nn.Conv2d(3, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )
        last_index = len(block_config) - 1
        for i, num_layers in enumerate(block_config):
            stage = nn.Sequential()
            for j in range(num_layers):
                stage.add_module('dense_block_layer_{}'.format(j + 1),
                                 DenseBlock(channels, growth_rate))
                channels += growth_rate  # each dense layer appends growth_rate channels
            self.features.add_module('dense_block_{}'.format(i + 1), stage)
            # No transition after the final dense block.
            if i != last_index:
                self.features.add_module('transition_block_{}'.format(i + 1),
                                         TransitionBlock(channels, channels // 2))
                channels //= 2
        # Final batch norm before the classifier head.
        self.features.add_module('bn', nn.BatchNorm2d(channels))
        self.classifier = nn.Linear(channels, num_classes)

    def forward(self, x):
        features = self.features(x)
        # ReLU after the trailing BN, then global average pool to 1x1.
        pooled = F.adaptive_avg_pool2d(F.relu(features, inplace=True), (1, 1))
        return self.classifier(torch.flatten(pooled, 1))
```
基于DenseNet的训练代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from densenet import DenseNet

# Select GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters (standard DenseNet CIFAR recipe: SGD, lr 0.1 with step decay).
batch_size = 128
learning_rate = 0.1
momentum = 0.9
weight_decay = 1e-4
num_epochs = 200

# Per-channel CIFAR-10 training-set statistics for input normalization.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2470, 0.2435, 0.2616)

# Fix: the original fed raw ToTensor() images. The lr=0.1 / 200-epoch
# step-decay schedule above is the standard DenseNet CIFAR recipe, which
# assumes normalized inputs plus random-crop/flip augmentation on the
# training split; without them the model badly overfits.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

# Load data (download=True on both splits so the script also works when the
# test split is fetched on a fresh machine).
train_dataset = datasets.CIFAR10(root='./data', train=True,
                                 transform=train_transform, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False,
                                transform=test_transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Initialize model
model = DenseNet().to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                      momentum=momentum, weight_decay=weight_decay)


def adjust_learning_rate(optimizer, epoch):
    """Step decay: divide the base lr by 10 at epoch 100 and by 100 at epoch 150."""
    lr = learning_rate
    if epoch >= 100:
        lr /= 10
    if epoch >= 150:
        lr /= 10  # cumulative with the previous step -> base lr / 100
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


# Train model
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    train_correct = 0
    total = 0
    adjust_learning_rate(optimizer, epoch)
    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        train_correct += predicted.eq(targets).sum().item()
        total += targets.size(0)
    train_accuracy = 100. * train_correct / total
    train_loss /= len(train_loader)

    # Evaluate on the held-out test split without tracking gradients.
    model.eval()
    test_loss = 0
    test_correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            test_correct += predicted.eq(targets).sum().item()
            total += targets.size(0)
    test_accuracy = 100. * test_correct / total
    test_loss /= len(test_loader)

    print('Epoch [{}/{}], Train Loss: {:.4f}, Train Accuracy: {:.2f}%, Test Loss: {:.4f}, Test Accuracy: {:.2f}%'
          .format(epoch + 1, num_epochs, train_loss, train_accuracy, test_loss, test_accuracy))
```
阅读全文