基于pytorch的批次大小为50的五分类混淆矩阵代码
时间: 2024-02-11 14:08:39 浏览: 77
利用pytorch实现图像分类的一个完整的代码
下面是一个基于PyTorch的批次大小为50的五分类混淆矩阵的代码示例:
```python
import torch
import numpy as np
# 假设我们有50个样本,每个样本有5个类别
batch_size = 50
num_classes = 5
# 生成随机的真实标签和预测标签
y_true = torch.randint(low=0, high=num_classes, size=(batch_size,))
y_pred = torch.randint(low=0, high=num_classes, size=(batch_size,))
# 将真实标签和预测标签转换为numpy数组
y_true = y_true.numpy()
y_pred = y_pred.numpy()
# 计算混淆矩阵
confusion_matrix = np.zeros((num_classes, num_classes))
for i in range(batch_size):
confusion_matrix[y_true[i]][y_pred[i]] += 1
print(confusion_matrix)
```
此代码将生成一个大小为5x5的混淆矩阵,并将其打印出来。你可以根据自己的需要修改代码中的`batch_size`和`num_classes`变量,以适应不同的数据集和分类任务。
阅读全文