如何在Python中加载已保存的模型?
时间: 2024-10-16 12:12:09 浏览: 69
python使用tensorflow保存、加载和使用模型的方法
5星 · 资源好评率100%
在Python中,如果你已经训练了一个模型并将其保存下来,通常使用深度学习库如TensorFlow或PyTorch,你可以通过以下步骤加载模型:
1. **TensorFlow (Keras)**:
- 使用`tf.keras.models.load_model()`函数,例如:
```python
model = tf.keras.models.load_model('path_to_your_saved_model.h5')
```
2. **PyTorch**:
- 如果是torch.nn.Module类型的模型,可以使用`torch.load()`:
```python
model = torch.load('path_to_your_saved_model.pth', map_location='cpu')
```
- 对于基于state_dict的模型,可以直接加载:
```python
model_dict = torch.load('path_to_your_saved_model.pth', map_location='cpu')['model_state']
model.load_state_dict(model_dict)
```
3. **Hugging Face Transformers**:
- 对于预训练模型,可以使用`AutoModel.from_pretrained()`:
```python
from transformers import AutoModel
model = AutoModel.from_pretrained('path/to/saved/model')
```
确保在加载之前,模型文件路径是正确的,并且加载时需要考虑是否需要将模型从GPU迁移到CPU(使用`map_location='cpu'`)。
阅读全文