torch下载预训练模型
时间: 2025-01-02 10:42:15 浏览: 12
### 如何在 PyTorch 中下载预训练模型
#### 使用 Torchvision 下载预训练模型
PyTorch 的 `torchvision` 库提供了一个简单的方法来获取预训练模型。对于大多数常见的计算机视觉任务,可以使用如下方式直接调用并加载预训练权重:
```python
import torchvision.models as models
# 获取 ResNet-18 预训练模型实例
model = models.resnet18(pretrained=True)
# 将模型设置为评估模式
model.eval()
```
这段代码会自动从互联网下载对应的预训练权重文件,并初始化给定架构的神经网络[^1]。
#### 修改默认下载路径
有时为了管理方便或者节省空间,希望改变这些大体积模型文件存储的位置。可以通过环境变量 `TORCH_MODEL_ZOO` 或者更推荐的方式是利用 `TORCH_HOME` 来指定自定义目录作为缓存位置:
```bash
export TORCH_HOME=/path/to/custom/cache/dir
```
之后所有的预训练模型都会被保存在这个新设定的地方[^3]。
#### 手动下载与加载特定版本的预训练模型
如果需要离线工作或是想要固定使用的具体版本,则可以选择先手动下载 `.pth` 文件再加载。例如针对 VGG19 模型的操作如下所示:
```python
from collections import OrderedDict
import torch
state_dict = torch.load('vgg19-dcbb9e9d.pth')
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k.replace('module.', '') # 去除多GPU训练时附加的名字前缀
new_state_dict[name] = v
pretrained_model = models.vgg19()
pretrained_model.load_state_dict(new_state_dict)
pretrained_model.eval() # 切换到推理模式
```
此方法允许用户精确控制所依赖的具体模型参数文件[^5]。
阅读全文