import numpy as np


class LogisticRegression(object):
    """Multinomial (softmax) logistic regression trained with full-batch gradient descent."""

    def __init__(self, input_size, output_size, eta=0.01, max_epoch=1000, eps=1e-7):
        """
        Constructor for LogisticRegression class.

        :param input_size: number of features in the input. Stored for reference
            only; ``fit`` infers the feature count from ``X``.
        :param output_size: number of classes in the output. Stored for reference
            only; ``fit`` infers the class count from ``y``.
        :param eta: learning rate for gradient descent
        :param max_epoch: maximum number of full-batch gradient steps taken by ``fit``
        :param eps: convergence tolerance; ``fit`` stops early once the largest
            absolute gradient component (weights and bias) is below this value
        """
        self.input_size = input_size
        self.output_size = output_size
        self.eta = eta
        self.max_epoch = max_epoch
        self.eps = eps
        # Populated by fit(); None means the model has not been trained yet.
        self.weights = None
        self.bias = None

    def fit(self, X, y):
        """
        Train the model on the given data with full-batch gradient descent.

        Weights and bias are re-initialised to zeros on every call.

        :param X: array-like training inputs of shape (n_samples, n_features)
        :param y: array-like one-hot (or per-class probability) targets of shape
            (n_samples, n_classes)
        :return: ``self``, so calls can be chained (``model.fit(X, y).predict(X)``)
        :raises ValueError: if ``X`` or ``y`` is not 2-D, or their sample counts differ
        """
        X = np.asarray(X, dtype=float)
        y = np.asarray(y, dtype=float)
        if X.ndim != 2 or y.ndim != 2:
            raise ValueError(
                "X and y must be 2-D arrays of shape (n_samples, n_features) and "
                "(n_samples, n_classes); got X.shape=%s, y.shape=%s" % (X.shape, y.shape)
            )
        n_samples, n_features = X.shape
        if y.shape[0] != n_samples:
            raise ValueError(
                "X and y must have the same number of samples; got %d and %d"
                % (n_samples, y.shape[0])
            )
        n_classes = y.shape[1]

        self.weights = np.zeros((n_features, n_classes))
        self.bias = np.zeros((1, n_classes))
        for _ in range(self.max_epoch):
            # Forward pass: per-sample class probabilities.
            y_pred = self.softmax(np.dot(X, self.weights) + self.bias)
            # Backward pass: gradient of the cross-entropy *summed* (not averaged)
            # over samples, so the effective step size grows with n_samples.
            error = y_pred - y
            grad_weights = np.dot(X.T, error)
            grad_bias = np.sum(error, axis=0, keepdims=True)
            # Gradient-descent update.
            self.weights -= self.eta * grad_weights
            self.bias -= self.eta * grad_bias
            # Converged only when BOTH gradients are small. Checking the weights
            # alone stops too early whenever grad_weights vanishes (e.g. all-zero
            # features) while the bias is still far from its optimum.
            # `initial=0.0` keeps the reduction valid for zero-feature inputs.
            largest_grad = max(
                np.max(np.abs(grad_weights), initial=0.0),
                np.max(np.abs(grad_bias), initial=0.0),
            )
            if largest_grad < self.eps:
                break
        return self

    def predict(self, X):
        """
        Return class probabilities for the given inputs.

        :param X: array-like inputs of shape (n_samples, n_features)
        :return: array of shape (n_samples, n_classes); each row is a softmax
            probability distribution over the classes
        :raises RuntimeError: if called before ``fit``
        """
        if self.weights is None or self.bias is None:
            raise RuntimeError("LogisticRegression must be fitted before calling predict()")
        X = np.asarray(X, dtype=float)
        return self.softmax(np.dot(X, self.weights) + self.bias)

    def softmax(self, z):
        """
        Apply a row-wise softmax.

        :param z: scores of shape (n_samples, n_classes)
        :return: array of the same shape whose rows are non-negative and sum to 1
        """
        # Subtracting each row's maximum does not change the result mathematically
        # but keeps np.exp from overflowing on large scores.
        exp_z = np.exp(z - np.max(z, axis=1, keepdims=True))
        return exp_z / np.sum(exp_z, axis=1, keepdims=True)
阅读全文
相关推荐
![-](https://img-home.csdnimg.cn/images/20241231045053.png)
![-](https://img-home.csdnimg.cn/images/20241231045053.png)
![-](https://img-home.csdnimg.cn/images/20241231045053.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![rar](https://img-home.csdnimg.cn/images/20241231044955.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![rar](https://img-home.csdnimg.cn/images/20241231044955.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![rar](https://img-home.csdnimg.cn/images/20241231044955.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![rar](https://img-home.csdnimg.cn/images/20241231044955.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![rar](https://img-home.csdnimg.cn/images/20241231044955.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)