python中 mini-batch梯度下降的代码是什么
时间: 2024-12-25 13:27:48 浏览: 4
在Python中,Mini-Batch梯度下降是一种优化算法,它将数据集分割成小批量(mini-batches),然后计算每个批次的梯度并更新模型参数。这里是一个简单的示例,假设我们有一个线性回归模型,使用Scikit-Learn库:
```python
import numpy as np
class LinearRegression:
def __init__(self, learning_rate=0.01, batch_size=32):
self.learning_rate = learning_rate
self.batch_size = batch_size
self.weights = None
self.bias = None
def fit(self, X, y):
m, n = X.shape
self.weights = np.zeros(n)
self.bias = 0
# 遍历整个数据集,分批训练
for i in range(0, m, self.batch_size):
batch_X = X[i:i + self.batch_size]
batch_y = y[i:i + self.batch_size]
# 梯度计算和参数更新
gradients = (2 / batch_X.shape[0]) * (batch_X.T @ (batch_y - self.predict(batch_X)))
self.weights -= self.learning_rate * gradients[:, 0] # 更新权重
self.bias -= self.learning_rate * gradients[0, 0] # 更新偏置
def predict(self, X):
return X @ self.weights + self.bias
# 使用示例
X = ... # 输入特征矩阵
y = ... # 输出标签向量
model = LinearRegression()
model.fit(X, y)
```
在这个例子中,`fit`方法执行了 Mini-Batch梯度下降,每次迭代处理数据的一个小批次。注意,这只是一个基础版本,实际应用可能需要更复杂的循环控制、数据预处理以及学习率调整等。
阅读全文