四层 BP 神经网络代码
时间: 2023-10-09 18:10:03 浏览: 50
下面是一个四层（输入层、两个隐藏层、输出层）BP 神经网络的 Python 代码示例，并以 XOR（异或）问题作为训练数据：
```python
# 导入所需的库
import numpy as np
# Activation function
def sigmoid(x):
    """Return the element-wise logistic sigmoid ``1 / (1 + exp(-x))``.

    Computed in a numerically stable form. The naive formula evaluates
    ``np.exp(-x)``, which overflows (RuntimeWarning, or FloatingPointError
    under ``np.errstate(over="raise")``) once ``-x`` exceeds roughly 709.
    Here ``exp`` is only applied to ``-|x|`` (always <= 0), so it cannot
    overflow.

    Args:
        x: A scalar or array-like of real numbers.

    Returns:
        A NumPy scalar for scalar input, otherwise an ndarray with the same
        shape as ``x``.
    """
    x = np.asarray(x)
    e = np.exp(-np.abs(x))  # always in (0, 1], never overflows
    # x >= 0: 1 / (1 + e^-x).  x < 0: e^x / (1 + e^x), which is algebraically equal.
    out = np.where(x >= 0, 1 / (1 + e), e / (1 + e))
    # Unwrap 0-d results so scalar input still yields a NumPy scalar.
    return out[()]
# BP (back-propagation) neural network class
class BPNeuralNetwork:
    """Four-layer BP network: input -> hidden1 -> hidden2 -> output.

    The two hidden layers and the output layer all apply the module-level
    ``sigmoid`` activation. ``train`` runs full-batch gradient descent on the
    squared error ``0.5 * sum((y - output) ** 2)``. Gradients are summed
    (not averaged) over the batch.

    Args:
        input_dim: Number of input features.
        hidden_dim1: Units in the first hidden layer.
        hidden_dim2: Units in the second hidden layer.
        output_dim: Number of output units.
        seed: Optional seed for weight/bias initialisation.
            ``None`` (the default) draws from NumPy's global RNG exactly as
            ``np.random.randn`` does, so ``np.random.seed`` still controls
            it. An int uses a private, reproducible
            ``np.random.default_rng(seed)``.
    """

    def __init__(self, input_dim, hidden_dim1, hidden_dim2, output_dim, *, seed=None):
        self.input_dim = input_dim
        self.hidden_dim1 = hidden_dim1
        self.hidden_dim2 = hidden_dim2
        self.output_dim = output_dim
        # np.random.standard_normal(shape) yields the same stream as
        # np.random.randn(*shape). Generator exposes the same method name.
        rng = np.random if seed is None else np.random.default_rng(seed)
        # Weights: standard-normal init, shape (fan_in, fan_out)
        self.weights1 = rng.standard_normal((self.input_dim, self.hidden_dim1))
        self.weights2 = rng.standard_normal((self.hidden_dim1, self.hidden_dim2))
        self.weights3 = rng.standard_normal((self.hidden_dim2, self.output_dim))
        # Biases: one per unit, broadcast across the batch dimension
        self.bias1 = rng.standard_normal(self.hidden_dim1)
        self.bias2 = rng.standard_normal(self.hidden_dim2)
        self.bias3 = rng.standard_normal(self.output_dim)

    def forward(self, X):
        """Run a forward pass and return the output-layer activations.

        Args:
            X: Inputs of shape (n_samples, input_dim). A single 1-D sample of
                length input_dim also works and yields a 1-D output.

        Returns:
            Sigmoid activations of the output layer, shape
            (n_samples, output_dim).

        Side effects:
            Caches ``hidden_layer1``, ``hidden_layer2`` and ``output_layer``
            on the instance. ``train`` reads them during back-propagation.
        """
        self.hidden_layer1 = sigmoid(np.dot(X, self.weights1) + self.bias1)
        self.hidden_layer2 = sigmoid(np.dot(self.hidden_layer1, self.weights2) + self.bias2)
        self.output_layer = sigmoid(np.dot(self.hidden_layer2, self.weights3) + self.bias3)
        return self.output_layer

    def train(self, X, y, learning_rate, epochs):
        """Fit the network in place with full-batch gradient descent.

        Args:
            X: Inputs of shape (n_samples, input_dim). A single 1-D sample is
                treated as a batch of one.
            y: Targets with n_samples * output_dim elements, e.g. shape
                (n_samples, output_dim) or, when output_dim == 1,
                shape (n_samples,).
            learning_rate: Step size applied to the batch-summed gradient.
            epochs: Number of full passes over the data.

        Raises:
            ValueError: If ``y`` cannot be reshaped to
                (n_samples, output_dim), or ``X``'s width does not match
                ``input_dim``.
        """
        # A 1-D X would make X.T a no-op and corrupt the weights1 gradient.
        X = np.atleast_2d(np.asarray(X, dtype=float))
        # A 1-D y of shape (n_samples,) would broadcast against the
        # (n_samples, output_dim) output into a wrong-shaped error matrix.
        y = np.asarray(y, dtype=float).reshape(X.shape[0], self.output_dim)
        for _ in range(epochs):
            # Forward pass (also caches the hidden activations used below)
            output = self.forward(X)
            # Backward pass: delta_* is -dLoss/d(pre-activation) per layer.
            # a * (1 - a) is the sigmoid derivative written via its output a.
            error = y - output
            delta_output = error * output * (1 - output)
            delta_hidden2 = (np.dot(delta_output, self.weights3.T)
                             * self.hidden_layer2 * (1 - self.hidden_layer2))
            delta_hidden1 = (np.dot(delta_hidden2, self.weights2.T)
                             * self.hidden_layer1 * (1 - self.hidden_layer1))
            # Update weights and biases. All deltas were computed above from
            # the pre-update weights.
            self.weights3 += learning_rate * np.dot(self.hidden_layer2.T, delta_output)
            self.weights2 += learning_rate * np.dot(self.hidden_layer1.T, delta_hidden2)
            self.weights1 += learning_rate * np.dot(X.T, delta_hidden1)
            self.bias3 += learning_rate * np.sum(delta_output, axis=0)
            self.bias2 += learning_rate * np.sum(delta_hidden2, axis=0)
            self.bias1 += learning_rate * np.sum(delta_hidden1, axis=0)
# Demo: train the network on the XOR truth table and print its predictions.
# Seed NumPy's global RNG so the random initialisation, and therefore the
# printed result, is the same on every run.
np.random.seed(0)
# Create the BP network: 2 inputs, two hidden layers of 4 units, 1 output
nn = BPNeuralNetwork(input_dim=2, hidden_dim1=4, hidden_dim2=4, output_dim=1)
# Inputs and target outputs (XOR)
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([[0], [1], [1], [0]])
# Train the network
nn.train(X, y, learning_rate=0.1, epochs=10000)
# Predict, and show the result instead of discarding it
output = nn.forward(X)
for sample, target, prediction in zip(X, y, output):
    print(f"input={sample.tolist()} target={target.tolist()} "
          f"prediction={np.round(prediction, 4).tolist()}")
```
相关推荐
![docx](https://img-home.csdnimg.cn/images/20210720083331.png)
![py](https://img-home.csdnimg.cn/images/20210720083646.png)
![pdf](https://img-home.csdnimg.cn/images/20210720083512.png)
![-](https://csdnimg.cn/download_wenku/file_type_column_c1.png)
![-](https://csdnimg.cn/download_wenku/file_type_column_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)