余弦损失python代码
时间: 2023-10-28 15:00:14 浏览: 141
余弦相似性计算及python代码实现过程解析
下面是一个示例的余弦损失函数的Python代码:
```python
import tensorflow as tf
import numpy as np
def cosine_loss(y_true, y_pred):
y_true = tf.math.l2_normalize(y_true, axis=-1)
y_pred = tf.math.l2_normalize(y_pred, axis=-1)
loss = tf.reduce_mean(tf.keras.losses.cosine_similarity(y_true, y_pred))
return loss
# 用法示例
y_true = np.array([[1, 2, 3], [4, 5, 6]])
y_pred = np.array([[1, 2, 3], [0, 0, 0]])
loss = cosine_loss(y_true, y_pred)
print(loss)
```
在上面的示例中,`y_true` 是一个张量,表示真实的标签值,`y_pred` 是一个张量,表示模型的预测值。该代码使用 TensorFlow 实现了余弦损失函数,通过调用 `tf.keras.losses.cosine_similarity` 计算余弦相似度,并取均值作为损失值。在示例中,`y_true` 和 `y_pred` 分别是两个样本的标签和预测值。
阅读全文