二分类问题中的交叉熵损失函数
时间: 2023-11-23 08:55:28 浏览: 150
二分类问题中的交叉熵损失函数是一种常用的损失函数,它可以用来衡量模型输出的概率分布与真实标签的差异。其公式如下:
$$
J(\theta)=-\frac{1}{m}\sum_{i=1}^{m}[y^{(i)}\log(h_{\theta}(x^{(i)}))+(1-y^{(i)})\log(1-h_{\theta}(x^{(i)}))]
$$
其中,$m$表示样本数量,$y^{(i)}$表示第$i$个样本的真实标签(0或1),$h_{\theta}(x^{(i)})$表示模型对第$i$个样本的预测概率,$\theta$表示模型的参数。
交叉熵损失函数的含义是,对于每个样本,如果真实标签为1,则希望模型输出的概率也越接近1越好;如果真实标签为0,则希望模型输出的概率也越接近0越好。同时,交叉熵损失函数也具有良好的数学性质,可以通过梯度下降等优化算法来求解模型参数。
下面是一个使用交叉熵损失函数训练二分类模型的示例代码:
```python
import numpy as np
# 定义sigmoid函数
def sigmoid(x):
return 1 / (1 + np.exp(-x))
# 定义交叉熵损失函数
def cross_entropy_loss(y_true, y_pred):
epsilon = 1e-7 # 避免log(0)的情况
return -np.mean(y_true * np.log(y_pred + epsilon) + (1 - y_true) * np.log(1 - y_pred + epsilon))
# 定义模型类
class LogisticRegression:
def __init__(self, lr=0.01, num_iter=100000, fit_intercept=True):
self.lr = lr # 学习率
self.num_iter = num_iter # 迭代次数
self.fit_intercept = fit_intercept # 是否拟合截距
self.theta = None # 模型参数
def fit(self, X, y):
if self.fit_intercept:
X = np.hstack([np.ones((X.shape[0], 1)), X]) # 添加一列全为1的特征,用于拟合截距
self.theta = np.zeros(X.shape[1]) # 初始化模型参数为0
for i in range(self.num_iter):
z = np.dot(X, self.theta) # 计算z值
h = sigmoid(z) # 计算预测概率
gradient = np.dot(X.T, (h - y)) / y.size # 计算梯度
self.theta -= self.lr * gradient # 更新模型参数
def predict_proba(self, X):
if self.fit_intercept:
X = np.hstack([np.ones((X.shape[0], 1)), X]) # 添加一列全为1的特征,用于拟合截距
return sigmoid(np.dot(X, self.theta)) # 计算预测概率
def predict(self, X, threshold=0.5):
return (self.predict_proba(X) >= threshold).astype(int) # 根据阈值将概率转换为类别
# 使用sklearn生成二分类数据集
from sklearn.datasets import make_classification
X, y = make_classification(n_samples=1000, n_features=10, n_informative=5, n_redundant=0, random_state=42)
# 划分训练集和测试集
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 训练模型
model = LogisticRegression(lr=0.1, num_iter=10000)
model.fit(X_train, y_train)
# 在测试集上评估模型
y_pred = model.predict(X_test)
print("Accuracy:", np.mean(y_pred == y_test))
print("Cross-entropy loss:", cross_entropy_loss(y_test, model.predict_proba(X_test)))
```
阅读全文