画交叉熵损失函数loss图的代码
时间: 2023-09-17 11:01:14 浏览: 161
画交叉熵损失函数的代码如下:
```python
import numpy as np
import matplotlib.pyplot as plt
def cross_entropy_loss(y, y_pred):
epsilon = 1e-10
loss = -np.sum(y * np.log(y_pred + epsilon))
return loss
def plot_loss():
y_pred = np.linspace(0.001, 1, 100)
y = np.ones(100)
losses = []
for pred in y_pred:
loss = cross_entropy_loss(y, pred)
losses.append(loss)
plt.plot(y_pred, losses)
plt.xlabel('Predicted Value')
plt.ylabel('Loss')
plt.title('Cross Entropy Loss')
plt.show()
plot_loss()
```
上述代码中,`cross_entropy_loss` 函数计算交叉熵损失,其中 `y` 是真实标签,`y_pred` 是预测标签。函数中,使用了一个小的 epsilon 值来防止 `np.log` 函数的输入为 0。`plot_loss` 函数生成了 100 个预测标签 `y_pred`,并计算每个预测标签对应的交叉熵损失。最后,使用 `plt.plot` 函数绘制出交叉熵损失与预测值的关系图,并添加了横轴和纵轴的标签和标题。最后通过 `plt.show` 函数显示图像。
阅读全文