torch.load cpu
时间: 2025-01-02 14:44:09 浏览: 13
### 使用 `torch.load` 在 CPU 上加载 PyTorch 模型或张量
当需要在 CPU 上加载由 GPU 训练的模型或保存的数据时,可以利用 `map_location` 参数来指定设备。这使得即使原始模型是在 GPU 上训练的,在缺乏 GPU 资源的情况下也能顺利加载。
对于仅限于 CPU 的环境而言,可以通过设置 `map_location=torch.device('cpu')` 来确保所有 tensors 都被映射到 CPU 设备上[^1]:
```python
import torch
# 假设 'model_checkpoint.pth' 是要加载的文件路径
checkpoint = torch.load('model_checkpoint.pth', map_location=torch.device('cpu'))
```
如果目标是从已保存的状态字典恢复整个模型实例,则除了上述操作外还需要创建一个新的模型对象并调用其 `.load_state_dict()` 方法传入状态字典[^2]:
```python
from my_model_definition import MyModel # 导入定义好的网络结构类MyModel
model = MyModel()
model.load_state_dict(checkpoint['state_dict']) # 如果 checkpoint 中存储的是 state dict
```
另外一种情况是直接加载一个简单的 tensor 或者其他 Python 对象而不仅仅是神经网络权重;这时同样适用上面提到的方式使用 `torch.load()` 函数加上合适的 `map_location` 参数即可完成加载工作.
阅读全文