resnet18网络模型代码python
时间: 2024-08-30 17:00:41 浏览: 49
Alexnet-resnet18 预训练模型pth
5星 · 资源好评率100%
ResNet18是一种深度残差网络(Residual Network),常用于计算机视觉任务,如图像分类。它是深度学习库PyTorch中的预训练模型之一。以下是使用Python和PyTorch库加载和展示ResNet18的基本代码:
```python
import torch
from torchvision import models, transforms
# 定义预处理步骤
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载预训练的ResNet18模型
model = models.resnet18(pretrained=True)
# 模型在GPU上运行,如果可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# 示例:前向传播并查看模型结构
input_image = transform(image_path) # image_path替换成实际的图片路径
input_image = input_image.unsqueeze(0).to(device)
output = model(input_image)
print(f"Model architecture: {model}")
```
在这个例子中,我们首先导入必要的模块,然后定义了数据预处理流程。接着加载预训练的ResNet18模型,并将其移动到设备(通常是GPU)。最后,我们对一个输入图像进行了前向传播,并打印出了模型的结构。
如果你想要了解如何训练这个模型或者使用它的特征层,可以查阅相关的PyTorch教程或者官方文档。
阅读全文