随机生成三类数据,每类20个样本,样本为二维向量,不使用sklearn且不封装的条件下构建一个隐藏层为4的BP神经网络进行分类并完成可视化
时间: 2024-02-22 09:59:08 浏览: 169
好的,我可以为您提供一份Python代码来实现这个任务。代码如下:
```python
import numpy as np
import matplotlib.pyplot as plt
# 随机生成三类数据
np.random.seed(0)
# 第一类数据
class1 = np.random.randn(20, 2) + np.array([0, 10])
# 第二类数据
class2 = np.random.randn(20, 2) + np.array([5, 5])
# 第三类数据
class3 = np.random.randn(20, 2) + np.array([10, 0])
# 将数据合并
data = np.vstack([class1, class2, class3])
# 构建标签
labels = np.array([0]*20 + [1]*20 + [2]*20)
# 打乱数据
indices = np.random.permutation(len(data))
data = data[indices]
labels = labels[indices]
# 构建神经网络模型
class NeuralNetwork:
def __init__(self):
self.inputSize = 2
self.outputSize = 3
self.hiddenSize = 4
self.W1 = np.random.randn(self.inputSize, self.hiddenSize)
self.W2 = np.random.randn(self.hiddenSize, self.outputSize)
def forward(self, X):
self.z = np.dot(X, self.W1)
self.z2 = self.sigmoid(self.z)
self.z3 = np.dot(self.z2, self.W2)
o = self.sigmoid(self.z3)
return o
def sigmoid(self, s):
return 1 / (1 + np.exp(-s))
def sigmoidPrime(self, s):
return s * (1 - s)
def backward(self, X, y, o):
self.o_error = y - o
self.o_delta = self.o_error * self.sigmoidPrime(o)
self.z2_error = self.o_delta.dot(self.W2.T)
self.z2_delta = self.z2_error * self.sigmoidPrime(self.z2)
self.W1 += X.T.dot(self.z2_delta)
self.W2 += self.z2.T.dot(self.o_delta)
def train(self, X, y):
o = self.forward(X)
self.backward(X, y, o)
# 训练神经网络
NN = NeuralNetwork()
for i in range(1000):
NN.train(data, np.eye(3)[labels])
# 生成网格点
x_min, x_max = data[:, 0].min() - 1, data[:, 0].max() + 1
y_min, y_max = data[:, 1].min() - 1, data[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
np.arange(y_min, y_max, 0.1))
# 预测网格点的分类
Z = NN.forward(np.c_[xx.ravel(), yy.ravel()])
Z = np.argmax(Z, axis=1)
Z = Z.reshape(xx.shape)
# 可视化分类结果
plt.contourf(xx, yy, Z, alpha=0.4)
plt.scatter(data[:, 0], data[:, 1], c=labels, s=20, edgecolor='k')
plt.show()
```
这段代码与上面的代码类似,不同之处在于数据是二维向量。我使用了相同的`NeuralNetwork`类来构建神经网络模型,使用了相同的随机梯度下降算法来训练模型。接着,我生成了一个网格点,用于可视化分类结果。最后,我使用`contourf()`函数绘制了分类结果的等高线图,使用`scatter()`函数绘制了原始数据的散点图。运行代码后,您将会看到一个可视化的分类结果。
阅读全文