Example PyTorch code for knowledge distillation of a CNN model
Below is a simple example of applying knowledge distillation to a CNN whose teacher model has already been trained:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
# KnowledgeDistiller is assumed to live in a local distillation.py module;
# a minimal sketch of such a class is given after the code.
from distillation import KnowledgeDistiller

# Define the teacher and student models
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.bn3 = nn.BatchNorm1d(512)
        self.relu3 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        # CIFAR-10 inputs are 32x32; two 2x2 poolings leave 8x8 feature maps
        x = x.view(-1, 64 * 8 * 8)
        x = self.fc1(x)
        x = self.bn3(x)
        x = self.relu3(x)
        x = self.fc2(x)
        return x
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        # Same layout as the teacher, but with half the channels and hidden units
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.relu2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(32 * 8 * 8, 256)
        self.bn3 = nn.BatchNorm1d(256)
        self.relu3 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = x.view(-1, 32 * 8 * 8)
        x = self.fc1(x)
        x = self.bn3(x)
        x = self.relu3(x)
        x = self.fc2(x)
        return x
# Define the data loaders and a plain training function
train_transforms = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transforms)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transforms)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)
def train(model, dataloader, criterion, optimizer, device):
    """One epoch of supervised training on hard labels; returns the mean loss."""
    model.train()
    running_loss = 0.0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    epoch_loss = running_loss / len(dataloader.dataset)
    return epoch_loss
# Set up the knowledge distiller
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
teacher_model = TeacherModel().to(device)
teacher_model.load_state_dict(torch.load('teacher_model.pth', map_location=device))
teacher_model.eval()  # the pretrained teacher stays frozen during distillation
student_model = StudentModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(student_model.parameters(), lr=0.1, momentum=0.9)
kd = KnowledgeDistiller(teacher_model, student_model, criterion, optimizer, device)
# Run the distillation training loop
num_epochs = 10
for epoch in range(num_epochs):
    # One hard-label pass (plain cross-entropy) to track the student's loss ...
    train_loss = train(student_model, train_loader, criterion, optimizer, device)
    print('Epoch {}/{}: train_loss = {:.4f}'.format(epoch + 1, num_epochs, train_loss))
    # ... followed by one distillation pass against the teacher's soft targets
    kd.train(train_loader, epoch)
```
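Note that `test_loader` is built above but never used. A short evaluation pass like the following (a minimal sketch, reusing the `device`, `student_model`, and `test_loader` defined above) can check the student's accuracy once training finishes:
```python
def evaluate(model, dataloader, device):
    # Accuracy over a dataset; no gradients are needed at eval time
    model.eval()
    correct = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            preds = model(inputs).argmax(dim=1)
            correct += (preds == labels).sum().item()
    return correct / len(dataloader.dataset)

print('Student test accuracy: {:.2%}'.format(evaluate(student_model, test_loader, device)))
```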
In the code above, we first define a teacher model (`TeacherModel`) with two convolutional layers and two fully connected layers, and a smaller student model (`StudentModel`) with the same layout but half the channels and hidden units. Both models subclass PyTorch's `nn.Module`. We also define CIFAR-10 data loaders and a plain hard-label training function.
We then construct a knowledge-distillation object (`KnowledgeDistiller`) that uses the teacher model to generate soft targets and trains the student against them. During distillation training, each call to `kd.train()` updates the student's parameters, while the `train()` function additionally reports the student's per-epoch loss on the hard labels.
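Concretely, "soft targets" are the teacher's output probabilities softened with a temperature `T > 1`, which exposes how the teacher ranks the wrong classes. A quick illustration with made-up logits (the values below are hypothetical, not from the models above):
```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[8.0, 2.0, 1.0]])  # hypothetical teacher logits
print(F.softmax(logits, dim=1))           # T=1: ~[0.997, 0.002, 0.001] (near one-hot)
print(F.softmax(logits / 4.0, dim=1))     # T=4: ~[0.72, 0.16, 0.12] (softer targets)
```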
Note that the implementation of the `KnowledgeDistiller` class is not provided here, but plenty of open-source implementations can be found on GitHub.
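For reference, here is a minimal sketch of what such a class might look like, using the classic Hinton-style distillation loss (temperature-scaled KL divergence blended with cross-entropy). The class name and constructor signature match the usage above; everything else, including the `temperature` and `alpha` hyperparameters, is an assumption, not the definitive implementation:
```python
import torch
import torch.nn.functional as F

class KnowledgeDistiller:
    """Hypothetical minimal implementation, not from the original post."""
    def __init__(self, teacher, student, criterion, optimizer, device,
                 temperature=4.0, alpha=0.5):
        self.teacher = teacher.to(device).eval()
        self.student = student.to(device)
        self.criterion = criterion  # hard-label loss, e.g. nn.CrossEntropyLoss()
        self.optimizer = optimizer
        self.device = device
        self.T = temperature        # softens both logit distributions
        self.alpha = alpha          # weight of the distillation term

    def train(self, dataloader, epoch):
        self.student.train()
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            with torch.no_grad():  # the teacher only provides fixed soft targets
                teacher_logits = self.teacher(inputs)
            student_logits = self.student(inputs)
            # KL divergence between temperature-scaled distributions; the T^2
            # factor keeps gradient magnitudes comparable across temperatures
            soft_loss = F.kl_div(
                F.log_softmax(student_logits / self.T, dim=1),
                F.softmax(teacher_logits / self.T, dim=1),
                reduction='batchmean') * (self.T ** 2)
            hard_loss = self.criterion(student_logits, labels)
            loss = self.alpha * soft_loss + (1.0 - self.alpha) * hard_loss
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item() * inputs.size(0)
        print('Epoch {}: distill_loss = {:.4f}'.format(
            epoch + 1, running_loss / len(dataloader.dataset)))
```
With this sketch, raising `temperature` puts more weight on the teacher's ranking of incorrect classes, while `alpha` trades off imitation of the teacher against fitting the hard labels.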