torch.load weight_only
时间: 2025-01-01 14:13:54 浏览: 5
### 加载PyTorch模型权重
为了仅加载PyTorch模型的权重而不加载整个模型结构,可以使用`torch.load`函数读取存储在`.pth`或`.pt`文件中的权重字典。之后通过调用模型实例上的`load_state_dict()`方法来更新模型参数。
```python
import torch
from torchvision import models
# 假设已经定义了一个与原模型架构相同的模型对象model
model = models.resnet18()
# 如果是在CPU上运行,则指定map_location为'cpu'
checkpoint = torch.load('path_to_model_weights.pth', map_location='cpu')
# 只加载模型的状态字典(即只加载权重)
model.load_state_dict(checkpoint)
if torch.cuda.is_available():
model.to('cuda')
```
当处理由多个组件组成的复杂检查点时——比如不仅包含模型状态还包括优化器状态和其他元数据的情况,应该先提取出实际的模型状态字典再传递给`load_state_dict()`[^3]:
```python
# 对于更复杂的检查点文件,可能需要这样做
checkpoint = torch.load('complex_checkpoint.pth.tar', map_location='cpu')
model.load_state_dict(checkpoint['state_dict']) # 或者 checkpoint['model']
```
阅读全文