写一个test loss代码
时间: 2024-03-01 08:52:06 浏览: 29
要写一个 Test Loss 的代码,需要先有一个模型和测试数据集。以下是一个简单的示例代码,可以计算出在测试数据集上的平均损失值:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
# 定义模型
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 1)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# 加载测试数据集
test_data = [(torch.randn(10), torch.randn(1)) for _ in range(100)]
test_loader = DataLoader(test_data, batch_size=10)
# 加载模型
model = MyModel()
model.load_state_dict(torch.load('model.pth'))
# 定义损失函数
criterion = nn.MSELoss()
# 计算测试集上的平均损失
test_loss = 0.0
with torch.no_grad():
model.eval()
for inputs, targets in test_loader:
outputs = model(inputs)
loss = criterion(outputs, targets)
test_loss += loss.item() * inputs.size(0)
test_loss /= len(test_data)
print(f"Test Loss: {test_loss:.4f}")
```
在这个示例代码中,我们定义了一个简单的模型 `MyModel`,并加载了一个测试数据集 `test_data`。在计算测试集上的损失时,我们使用了 PyTorch 中的 `DataLoader` 来批量加载数据。我们使用了 `nn.MSELoss()` 作为损失函数,计算每个批次的损失,最后求平均值得到测试集上的平均损失值。