class LogisticRegression(object): def __init__(self, input_size, output_size, eta, max_epoch, eps):
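Below is a complete NumPy implementation of such a class: a multinomial (softmax) logistic regression trained by batch gradient descent. Each epoch runs a forward pass to get class probabilities, forms the gradient of the mean cross-entropy loss (`X.T @ (y_pred - y_onehot) / n_samples`), updates the weights and bias, and stops early once the largest gradient component drops below `eps`: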
```python
import numpy as np


class LogisticRegression:
    def __init__(self, input_size, output_size, eta=0.01, max_epoch=1000, eps=1e-5):
        """
        Constructor for the LogisticRegression class.

        :param input_size: int, size of the input data
        :param output_size: int, number of output classes
        :param eta: float, learning rate (default=0.01)
        :param max_epoch: int, maximum number of epochs (default=1000)
        :param eps: float, convergence threshold (default=1e-5)
        """
        self.input_size = input_size
        self.output_size = output_size
        self.eta = eta
        self.max_epoch = max_epoch
        self.eps = eps
        self.weights = None
        self.bias = None

    def fit(self, X, y):
        """
        Fit the logistic regression model on the training data.

        :param X: numpy array of shape (n_samples, input_size), input data
        :param y: numpy array of shape (n_samples,), integer labels in [0, output_size)
        """
        n_samples = X.shape[0]
        self.weights = np.zeros((self.input_size, self.output_size))
        self.bias = np.zeros((1, self.output_size))
        y_onehot = self._onehot(y, self.output_size)
        for epoch in range(self.max_epoch):
            # Forward pass: logits -> class probabilities
            z = np.dot(X, self.weights) + self.bias
            y_pred = self._softmax(z)
            # Backward pass: gradient of the mean cross-entropy loss
            error = y_pred - y_onehot
            grad_weights = np.dot(X.T, error) / n_samples
            grad_bias = np.sum(error, axis=0, keepdims=True) / n_samples
            # Gradient-descent update of weights and bias
            self.weights -= self.eta * grad_weights
            self.bias -= self.eta * grad_bias
            # Stop early once the gradient has (nearly) vanished
            if np.max(np.abs(grad_weights)) < self.eps:
                break

    def predict(self, X):
        """
        Predict the class labels for the input data.

        :param X: numpy array of shape (n_samples, input_size), input data
        :return: numpy array of shape (n_samples,), predicted class labels
        """
        z = np.dot(X, self.weights) + self.bias
        y_pred = self._softmax(z)
        return np.argmax(y_pred, axis=1)

    def _softmax(self, z):
        """
        Compute the softmax activation function.

        :param z: numpy array of shape (n_samples, output_size), logits
        :return: numpy array of shape (n_samples, output_size), row-wise probabilities
        """
        # Subtract the row-wise max before exponentiating for numerical stability
        exp_z = np.exp(z - np.max(z, axis=1, keepdims=True))
        return exp_z / np.sum(exp_z, axis=1, keepdims=True)

    def _onehot(self, y, output_size):
        """
        Convert the target labels to one-hot encoding.

        :param y: numpy array of shape (n_samples,), integer target labels
        :param output_size: int, number of output classes
        :return: numpy array of shape (n_samples, output_size), one-hot encoding
        """
        onehot = np.zeros((len(y), output_size))
        onehot[np.arange(len(y)), y] = 1
        return onehot
```
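A minimal usage sketch follows. It assumes integer labels `0..output_size-1` (required by `_onehot`); the toy dataset, class centers, and hyperparameters here are illustrative choices, not part of the class itself:

```python
import numpy as np

# Toy 2-feature, 3-class problem: Gaussian blobs centered at (0,0), (2,2), (4,4)
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(loc=c, scale=0.5, size=(50, 2)) for c in (0, 2, 4)])
y = np.repeat(np.arange(3), 50)

model = LogisticRegression(input_size=2, output_size=3, eta=0.1, max_epoch=5000)
model.fit(X, y)
pred = model.predict(X)
print("training accuracy:", np.mean(pred == y))
```

Because the model is linear in the features, it can only learn linearly separable class boundaries; for harder problems you would add feature transformations or a different model.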