torch.load怎么用
时间: 2025-01-01 22:22:33 浏览: 6
### 使用 `torch.load` 加载模型
#### 负载模型参数
为了加载之前保存的模型参数,可以使用 `torch.load()` 函数。此函数不仅限于加载模型权重;还可以用于恢复整个训练状态,包括优化器的状态字典等[^1]。
```python
import torch
# 假设已经有一个定义好的模型实例 model
state_dict = torch.load('model_weights.pth', map_location=torch.device('cpu')) # 将模型加载至 CPU 上
model.load_state_dict(state_dict) # 将加载的权重复制到模型中
```
这段代码展示了如何从文件 `'model_weights.pth'` 中读取先前存储的模型权重并将其应用于新的模型实例。这里特别指定了 `map_location` 参数来确保即使是在不同设备(如 GPU 或 CPU)之间移动时也能正确加载数据[^2]。
#### 处理不同的硬件环境
当在不同于原始保存位置的不同计算资源上运行程序时——比如在一个只有 CPU 的环境中加载由 GPU 训练得到的数据——可以通过指定 `map_location` 来解决潜在的问题:
```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
loaded_model = TheModelClass(*args, **kwargs) # 创建一个新的模型对象
loaded_model.load_state_dict(torch.load(PATH, map_location=device))
loaded_model.to(device)
```
上述片段会检查当前是否有可用 CUDA 设备,并据此调整加载行为以适应目标平台的要求。
#### 自定义解码选项
对于某些特殊情况下的旧版本 PyTorch 文件或其他编码方式保存的对象,可能需要提供额外的关键字参数给底层使用的 Pickle 解析工具:
```python
custom_params = {'encoding': 'latin1'} # 对应特定情况所需的自定义参数
data = torch.load('legacy_data.pkl', pickle_load_args=custom_params)
```
这允许更灵活地处理各种类型的序列化数据,特别是那些采用非默认设置创建的情况[^3]。
阅读全文