torchvision.models,
时间: 2025-01-04 13:35:27 浏览: 19
### Torchvision Models 使用指南
Torchvision 是 PyTorch 的官方计算机视觉库,提供了多种预训练模型以及常用的数据集加载工具。`torchvision.models` 模块包含了多个经典的卷积神经网络架构,这些模型可以直接用于推理或作为迁移学习的基础。
#### 可用模型列表
以下是 `torchvision.models` 中一些常用的模型:
- **AlexNet**: 早期的经典CNN结构之一。
- **VGG**: 提供不同层数的选择(如 VGG11, VGG13, VGG16 和 VGG19),具有简单的堆叠层设计[^1]。
- **ResNet**: 包含残差连接机制,有效解决了深层网络中的梯度消失问题;有 ResNet18, ResNet34, ResNet50 等版本可供选择。
- **DenseNet**: 特征重用率高,能够实现更高效的前向传播路径。
- **SqueezeNet**: 设计紧凑,在保持较高精度的同时减少了参数数量。
- **MobileNet v2/v3**: 面向移动设备优化的小型高效模型。
- **EfficientNet**: 利用了复合缩放的方法来平衡宽度、深度和分辨率之间的关系。
- **RegNet**: 基于规则化图理论构建的一族通用 CNN 架构。
- **Vision Transformers (ViT)**: 将 Transformer 应用于图像分类任务的新颖尝试。
#### 获取并使用预训练模型
要获得上述任何一个预训练好的模型实例非常简单,只需调用相应函数即可。下面是一个例子展示如何加载带有 ImageNet 权重的 ResNet50 并对其进行预测操作:
```python
import torch
from torchvision import models, transforms
from PIL import Image
# 加载预训练的 resnet50 模型
model = models.resnet50(pretrained=True)
# 设置为评估模式
model.eval()
# 定义输入图片转换方式
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# 打开一张测试图片
image_path = 'example.jpg'
input_image = Image.open(image_path).convert('RGB')
# 对图片做变换处理
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # 创建批次维度
if torch.cuda.is_available():
input_batch = input_batch.to('cuda')
model.to('cuda')
with torch.no_grad(): # 关闭自动求导功能以节省内存
output = model(input_batch)
probabilities = torch.nn.functional.softmax(output[0], dim=0)
print(probabilities.argmax().item())
```
此代码片段展示了从准备数据到最终得到类别标签的过程。注意这里假设已经安装好了必要的依赖项,并且有一个名为 example.jpg 的本地图片文件存在。
阅读全文