请帮我写出基于pytorch resnet实现图片2分类的代码
时间: 2023-11-03 18:04:15 浏览: 86
基于Pytorch TorchHub和RESNET的图像分类案例代码
5星 · 资源好评率100%
好的,以下是基于 PyTorch 实现图片二分类任务的 ResNet 代码示例:
首先,导入所需的包:
```python
import torch
import torch.nn as nn
import torchvision.models as models
```
接着,定义 ResNet 模型:
```python
class ResNetModel(nn.Module):
def __init__(self, num_classes):
super(ResNetModel, self).__init__()
self.resnet = models.resnet50(pretrained=True)
num_feats = self.resnet.fc.in_features
self.resnet.fc = nn.Linear(num_feats, num_classes)
def forward(self, x):
return self.resnet(x)
```
在这个模型中,我们使用了预训练的 ResNet-50 作为特征提取器,并在其顶部添加了一个全连接层进行分类。`num_classes` 参数表示数据集的类别数。
接着,定义损失函数和优化器:
```python
model = ResNetModel(num_classes=2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
```
这里使用了交叉熵损失函数和 Adam 优化器。
接下来,定义训练函数:
```python
def train(model, dataloader, criterion, optimizer, device):
model.train()
running_loss = 0.0
running_corrects = 0
for inputs, labels in dataloader:
inputs = inputs.to(device)
labels = labels.to(device)
optimizer.zero_grad()
with torch.set_grad_enabled(True):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
epoch_loss = running_loss / len(dataloader.dataset)
epoch_acc = running_corrects.double() / len(dataloader.dataset)
return epoch_loss, epoch_acc
```
这个函数用于训练模型。其中,`dataloader` 表示数据集的数据加载器,`device` 表示使用的设备(如 CPU 或 GPU)。
最后,定义测试函数:
```python
def test(model, dataloader, criterion, device):
model.eval()
running_loss = 0.0
running_corrects = 0
for inputs, labels in dataloader:
inputs = inputs.to(device)
labels = labels.to(device)
with torch.set_grad_enabled(False):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
epoch_loss = running_loss / len(dataloader.dataset)
epoch_acc = running_corrects.double() / len(dataloader.dataset)
return epoch_loss, epoch_acc
```
这个函数用于测试模型在测试集上的性能。
完整的训练流程如下:
```python
num_epochs = 10
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# dataloaders
train_dataloader = ...
val_dataloader = ...
# model, loss function, optimizer
model = ResNetModel(num_classes=2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# train and evaluate
for epoch in range(num_epochs):
train_loss, train_acc = train(model, train_dataloader, criterion, optimizer, device)
val_loss, val_acc = test(model, val_dataloader, criterion, device)
print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f} - Train Acc: {train_acc:.4f} - Val Loss: {val_loss:.4f} - Val Acc: {val_acc:.4f}")
```
其中,`train_dataloader` 和 `val_dataloader` 分别表示训练集和验证集的数据加载器。
希望这个代码示例可以帮到你!
阅读全文