只使用 numpy 编写逻辑回归算法,对 iris 数据进行多分类并可视化
时间: 2024-02-09 10:13:11 浏览: 67
好的,让我来回答你的问题。
首先,需要导入需要用到的库,包括:
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
```
然后,我们需要加载 iris 数据集,并对数据进行预处理:
```python
iris = load_iris()
X = iris.data
y = iris.target
y = y.reshape(-1, 1)
# 对 y 进行 one-hot 编码
m = len(y)
n = np.max(y) + 1
y_one_hot = np.zeros((m, n))
y_one_hot[np.arange(m), y.reshape(-1)] = 1
```
接下来,我们需要定义逻辑回归模型:
```python
class LogisticRegression:
def __init__(self, input_dim, output_dim, lr=0.01):
self.W = np.zeros((input_dim, output_dim))
self.b = np.zeros((1, output_dim))
self.lr = lr
def sigmoid(self, z):
return 1 / (1 + np.exp(-z))
def forward(self, X):
z = np.dot(X, self.W) + self.b
a = self.sigmoid(z)
return a
def backward(self, X, y, a):
m = X.shape[0]
dz = a - y
dW = np.dot(X.T, dz) / m
db = np.sum(dz, axis=0, keepdims=True) / m
return dW, db
def update(self, dW, db):
self.W = self.W - self.lr * dW
self.b = self.b - self.lr * db
def train(self, X, y, num_iters):
for i in range(num_iters):
a = self.forward(X)
dW, db = self.backward(X, y, a)
self.update(dW, db)
def predict(self, X):
a = self.forward(X)
y_pred = np.argmax(a, axis=1)
return y_pred
```
然后,我们需要将数据集分为训练集和测试集:
```python
X_train, X_test, y_train, y_test = train_test_split(X, y_one_hot, test_size=0.2)
```
接下来,我们需要训练模型:
```python
input_dim = X_train.shape[1]
output_dim = y_train.shape[1]
lr = 0.1
num_iters = 1000
model = LogisticRegression(input_dim, output_dim, lr)
model.train(X_train, y_train, num_iters)
```
最后,我们可以绘制出决策边界和数据点的散点图来可视化分类效果:
```python
x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02), np.arange(y_min, y_max, 0.02))
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Spectral)
plt.show()
```
以上就是使用 numpy 编写逻辑回归算法对 iris 数据进行多分类并可视化的完整代码。
阅读全文