用bp算法写一个鸢尾花分类的代码
时间: 2023-12-06 09:00:32 浏览: 91
鸢尾花分类是机器学习领域中一个经典的问题,可以使用BP算法(反向传播算法)来解决。BP算法是一种用于训练人工神经网络的方法,它能够根据训练集的样本数据来调整网络的权重和偏差,从而达到对未知数据进行分类的目的。
下面是使用BP算法进行鸢尾花分类的代码示例:
1. 导入所需的库和数据集:
```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
iris = load_iris()
data = iris.data
target = iris.target
X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2, random_state=0)
```
2. 定义激活函数和神经网络模型:
```python
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def sigmoid_derivative(x):
return x * (1 - x)
class NeuralNetwork:
def __init__(self, input_size, hidden_size, output_size):
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.weights1 = np.random.randn(self.input_size, self.hidden_size)
self.weights2 = np.random.randn(self.hidden_size, self.output_size)
self.bias1 = np.random.randn(1, self.hidden_size)
self.bias2 = np.random.randn(1, self.output_size)
def forward(self, X):
self.hidden_layer = sigmoid(np.dot(X, self.weights1) + self.bias1)
self.output_layer = sigmoid(np.dot(self.hidden_layer, self.weights2) + self.bias2)
def backward(self, X, y):
delta_output = (y - self.output_layer) * sigmoid_derivative(self.output_layer)
delta_hidden = np.dot(delta_output, self.weights2.T) * sigmoid_derivative(self.hidden_layer)
self.weights2 += np.dot(self.hidden_layer.T, delta_output)
self.weights1 += np.dot(X.T, delta_hidden)
self.bias2 += np.sum(delta_output, axis=0)
self.bias1 += np.sum(delta_hidden, axis=0)
def train(self, X, y, epochs):
for i in range(epochs):
self.forward(X)
self.backward(X, y)
def predict(self, X):
self.forward(X)
return np.argmax(self.output_layer, axis=1)
```
3. 创建神经网络模型并进行训练和预测:
```python
input_size = X_train.shape[1]
hidden_size = 8 # 设置隐藏层节点数
output_size = len(np.unique(y_train)) # 输出层节点数
epochs = 1000
model = NeuralNetwork(input_size, hidden_size, output_size)
model.train(X_train, y_train, epochs)
y_pred = model.predict(X_test)
```
以上就是使用BP算法编写鸢尾花分类的代码示例,通过训练神经网络模型,可以使用测试集数据进行预测,并得到预测结果。
阅读全文