基于pytorch,帮我写一段代码用于实现混淆矩阵。对分类神经网络模型进行评估
时间: 2023-03-28 07:04:24 浏览: 110
混淆矩阵分类性能评价及Python实现.pdf
好的,以下是基于 PyTorch 实现混淆矩阵的代码:
```python
import torch
import numpy as np
from sklearn.metrics import confusion_matrix
# 定义分类神经网络模型
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = torch.nn.Linear(10, 5)
self.fc2 = torch.nn.Linear(5, 2)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# 定义评估函数
def evaluate(model, dataloader):
model.eval()
y_true = []
y_pred = []
with torch.no_grad():
for inputs, labels in dataloader:
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
y_true.extend(labels.numpy())
y_pred.extend(predicted.numpy())
return y_true, y_pred
# 生成数据
X = np.random.rand(100, 10)
y = np.random.randint(, 2, 100)
# 转换为 PyTorch 的 Tensor
X_tensor = torch.from_numpy(X).float()
y_tensor = torch.from_numpy(y).long()
# 定义数据集和数据加载器
dataset = torch.utils.data.TensorDataset(X_tensor, y_tensor)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=10)
# 初始化模型
model = Net()
# 评估模型
y_true, y_pred = evaluate(model, dataloader)
# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)
print(cm)
```
以上代码可以对分类神经网络模型进行评估,并输出混淆矩阵。
阅读全文