基于pytorch的DenseNet输入图片大小为11*11的分类模型 ,一通道的图像分类代码
时间: 2024-02-09 07:10:59 浏览: 65
以下是基于PyTorch的DenseNet模型,用于11*11黑白图像分类的代码:
``` python
import torch
import torch.nn as nn
import torch.optim as optim
class DenseNet(nn.Module):
    """Small DenseNet-style classifier for 1-channel 11x11 images (10 classes).

    Three dense blocks (each concatenating its 12-channel output with its own
    input, DenseNet-style) separated by 1x1-conv + 2x2-avg-pool transition
    layers, followed by global average pooling and a linear head.
    """

    def __init__(self):
        super(DenseNet, self).__init__()
        # Stem: 1 -> 16 channels, 3x3 conv with padding keeps 11x11 spatial size.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1, bias=False)
        # Dense block 1: BN-ReLU-1x1 bottleneck then BN-ReLU-3x3; emits 12 channels.
        self.dense1 = nn.Sequential(
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 12, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(12),
            nn.ReLU(inplace=True),
            nn.Conv2d(12, 12, kernel_size=3, stride=1, padding=1, bias=False),
        )
        # Transition 1: input is cat(16, 12) = 28 channels (the original used 24,
        # which fails at runtime with a channel mismatch); compress to 6 channels
        # and halve the spatial size (11 -> 5).
        self.trans1 = nn.Sequential(
            nn.BatchNorm2d(28),
            nn.ReLU(inplace=True),
            nn.Conv2d(28, 6, kernel_size=1, stride=1, bias=False),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        # Dense block 2: 6 -> 12 channels, same bottleneck pattern as block 1.
        self.dense2 = nn.Sequential(
            nn.BatchNorm2d(6),
            nn.ReLU(inplace=True),
            nn.Conv2d(6, 12, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(12),
            nn.ReLU(inplace=True),
            nn.Conv2d(12, 12, kernel_size=3, stride=1, padding=1, bias=False),
        )
        # Transition 2: cat(6, 12) = 18 channels in, 6 out; spatial 5 -> 2.
        self.trans2 = nn.Sequential(
            nn.BatchNorm2d(18),
            nn.ReLU(inplace=True),
            nn.Conv2d(18, 6, kernel_size=1, stride=1, bias=False),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        # Dense block 3: 6 -> 12 channels.
        self.dense3 = nn.Sequential(
            nn.BatchNorm2d(6),
            nn.ReLU(inplace=True),
            nn.Conv2d(6, 12, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(12),
            nn.ReLU(inplace=True),
            nn.Conv2d(12, 12, kernel_size=3, stride=1, padding=1, bias=False),
        )
        # Global average pooling collapses the remaining spatial extent to 1x1.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Classifier head: cat(6, 12) = 18 features after dense block 3
        # (the original used 12, a shape mismatch at runtime).
        self.fc = nn.Linear(18, 10)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: tensor of shape (N, 1, 11, 11).

        Returns:
            Logits of shape (N, 10).
        """
        x = self.conv1(x)
        # Dense connectivity: concatenate each block's output with its input.
        x = torch.cat([x, self.dense1(x)], dim=1)  # -> 28 channels, 11x11
        x = self.trans1(x)                         # -> 6 channels, 5x5
        x = torch.cat([x, self.dense2(x)], dim=1)  # -> 18 channels, 5x5
        x = self.trans2(x)                         # -> 6 channels, 2x2
        x = torch.cat([x, self.dense3(x)], dim=1)  # -> 18 channels, 2x2
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)                  # flatten to (N, 18)
        return self.fc(x)
# Instantiate the network, the loss criterion, and the optimizer.
model = DenseNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001)

# Train for 10 epochs over `trainloader` (a DataLoader the caller must supply),
# logging the mean loss every 2000 mini-batches.
for epoch in range(10):
    running_loss = 0.0
    for batch_idx, (inputs, labels) in enumerate(trainloader):
        # Clear gradients accumulated by the previous step.
        optimizer.zero_grad()
        # Forward pass, backward pass, parameter update.
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        # Accumulate and periodically report the running loss.
        running_loss += loss.item()
        if batch_idx % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 2000))
            running_loss = 0.0
```
其中,`trainloader`是数据集的dataloader。在训练之前,需要将数据集转换为tensor并标准化。此外,还需要对模型进行调参以取得更好的效果。
阅读全文