ResNet18实现图像分类的实例代码
时间: 2024-11-22 22:06:39 浏览: 21
基于python inference 、ResNet实现的图像分类
ResNet18在PyTorch中实现图像分类的基本步骤如下:
```python
import torch
from torchvision import models, transforms
# 加载预训练的ResNet18模型
model = models.resnet18(pretrained=True)
# 冻结模型的所有层,以防改变预训练权重
for param in model.parameters():
param.requires_grad = False
# 取出最后一层全连接层,因为我们要替换它以适应新的类别数
num_features = model.fc.in_features
model.fc = torch.nn.Linear(num_features, num_classes) # num_classes是你分类任务的类别总数
# 数据预处理
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 假设你有一个名为'data_loader'的批次加载器
images, labels = data_loader[0] # 获取一批图像和对应的标签
# 将输入数据转换成Float并归一化
inputs = images.float().cuda() / 255
# 前向传播并获取预测结果
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
```
这里的关键点是在模型的最后一层替换为一个新的线性层,以匹配你的目标分类类别,并在前传之前对数据进行预处理。
阅读全文