用python写一个能输出图形函数的线性分类器完整代码
时间: 2024-05-16 10:13:59 浏览: 137
以下是一个用Python编写的线性分类器的代码:
```python
import numpy as np
import matplotlib.pyplot as plt
class LinearClassifier:
def __init__(self, input_size, output_size):
self.W = np.random.randn(input_size, output_size)
self.b = np.zeros((1, output_size))
def forward(self, X):
return np.dot(X, self.W) + self.b
def backward(self, X, y, y_pred, learning_rate):
m = X.shape[0]
delta = y_pred - y
dW = (1/m) * np.dot(X.T, delta)
db = (1/m) * np.sum(delta, axis=0, keepdims=True)
self.W -= learning_rate * dW
self.b -= learning_rate * db
def train(self, X, y, learning_rate=0.1, epochs=100):
for epoch in range(epochs):
y_pred = self.forward(X)
self.backward(X, y, y_pred, learning_rate)
def predict(self, X):
return np.argmax(self.forward(X), axis=1)
# Generate random data
np.random.seed(0)
X = np.random.randn(100, 2)
y = np.array([0]*50 + [1]*50)
# Train classifier
clf = LinearClassifier(2, 2)
clf.train(X, y)
# Plot decision boundary
x1 = np.linspace(-3, 3, 100)
x2 = np.linspace(-3, 3, 100)
xx1, xx2 = np.meshgrid(x1, x2)
X_grid = np.array([xx1.ravel(), xx2.ravel()]).T
y_grid = clf.predict(X_grid)
y_grid = y_grid.reshape(xx1.shape)
plt.contourf(xx1, xx2, y_grid, cmap=plt.cm.Spectral, alpha=0.8)
plt.scatter(X[:,0], X[:,1], c=y, cmap=plt.cm.Spectral)
plt.show()
```
这个代码包含一个名为 `LinearClassifier` 的类,用于训练和预测线性分类器。在这个例子中,我们生成了一个二维的随机数据集,其中包含50个类别为0的点和50个类别为1的点。然后,我们使用 `LinearClassifier` 类来训练一个线性分类器,并使用 `matplotlib` 库绘制决策边界。
阅读全文