以初学者角度编写一个感知机分类mnsit数据集代码
时间: 2023-05-30 19:05:06 浏览: 69
以下是一个基本的感知机分类MNIST数据集的例子,适合初学者理解:
```python
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
# 加载MNIST数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 将图像数据展平成一维数组
x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
# 将像素值归一化到[0, 1]区间
x_train = x_train / 255
x_test = x_test / 255
# 将标签转换为0和1
y_train = np.where(y_train == 0, 1, -1)
y_test = np.where(y_test == 0, 1, -1)
# 定义感知机模型
class Perceptron(object):
def __init__(self, learning_rate=0.1, max_iter=1000):
self.learning_rate = learning_rate
self.max_iter = max_iter
def fit(self, X, y):
self.w = np.zeros(X.shape[1])
self.b = 0
for i in range(self.max_iter):
for xi, yi in zip(X, y):
if yi * (np.dot(xi, self.w) + self.b) <= 0:
self.w += self.learning_rate * yi * xi
self.b += self.learning_rate * yi
def predict(self, X):
return np.sign(np.dot(X, self.w) + self.b)
# 训练感知机模型
model = Perceptron()
model.fit(x_train, y_train)
# 在测试集上评估模型
acc = np.mean(model.predict(x_test) == y_test)
print(f"Accuracy: {acc}")
# 可视化权重
plt.imshow(model.w.reshape(28, 28), cmap="gray")
plt.axis("off")
plt.show()
```
此代码使用MNIST数据集来训练一个感知机模型,将图像数据展平为一个一维数组,并将像素值归一化到[0, 1]区间。然后,将标签转换为0和1,其中1表示数字“0”。定义了一个`Perceptron`类来实现感知机模型。在训练期间,使用`fit`方法来更新权重和偏置,直到达到最大迭代次数或训练误差为零为止。最后,使用`predict`方法在测试集上进行预测,并计算模型的准确度。还可视化模型的权重。