pytorch如何加载训练好的模型进行分类评估
时间: 2023-09-12 13:11:26 浏览: 112
在 PyTorch 中,可以使用 `torch.load()` 方法加载已经训练好的模型参数。一般来说,我们需要加载模型的参数和模型的结构,因此需要先定义好模型的结构,然后再将训练好的参数加载进来。接下来,我们可以用加载好的模型对测试集进行分类评估。
以下是一个简单的示例代码:
```python
import torch
import torch.nn as nn
# 定义模型结构
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.fc1 = nn.Linear(10, 20)
self.fc2 = nn.Linear(20, 2)
def forward(self, x):
x = self.fc1(x)
x = nn.functional.relu(x)
x = self.fc2(x)
return x
# 加载模型参数
model_path = 'model.pt'
model = Model()
model.load_state_dict(torch.load(model_path))
# 对测试集进行分类评估
test_data = torch.randn(100, 10)
test_labels = torch.randint(0, 2, (100,))
with torch.no_grad():
outputs = model(test_data)
_, predicted = torch.max(outputs.data, 1)
accuracy = (predicted == test_labels).sum().item() / test_labels.size(0)
print('Accuracy: %.2f%%' % (accuracy * 100))
```
在上面的示例中,我们首先定义了一个简单的模型结构 `Model`,包含两个全连接层。然后,我们使用 `torch.load()` 方法加载已经训练好的模型参数,并将其加载到模型中。最后,我们使用加载好的模型对测试集进行分类评估,计算准确率并打印输出。
阅读全文