bp神经网络回归python代码
时间: 2023-07-08 22:29:15 浏览: 90
BP神经网络python代码
以下是一个简单的 BP 神经网络回归的 Python 代码:
```python
import numpy as np
class NeuralNetwork:
def __init__(self, input_size, hidden_size, output_size):
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.weights_hidden = np.random.randn(self.input_size, self.hidden_size)
self.bias_hidden = np.random.randn(self.hidden_size)
self.weights_output = np.random.randn(self.hidden_size, self.output_size)
self.bias_output = np.random.randn(self.output_size)
def sigmoid(self, x):
return 1 / (1 + np.exp(-x))
def sigmoid_derivative(self, x):
return x * (1 - x)
def forward(self, x):
self.hidden_layer = self.sigmoid(np.dot(x, self.weights_hidden) + self.bias_hidden)
self.output_layer = self.sigmoid(np.dot(self.hidden_layer, self.weights_output) + self.bias_output)
return self.output_layer
def backward(self, x, y, output):
self.output_error = y - output
self.output_delta = self.output_error * self.sigmoid_derivative(output)
self.hidden_error = np.dot(self.output_delta, self.weights_output.T)
self.hidden_delta = self.hidden_error * self.sigmoid_derivative(self.hidden_layer)
self.weights_output += np.dot(self.hidden_layer.T, self.output_delta)
self.bias_output += np.sum(self.output_delta, axis=0)
self.weights_hidden += np.dot(x.T, self.hidden_delta)
self.bias_hidden += np.sum(self.hidden_delta, axis=0)
def train(self, x, y, epochs):
for i in range(epochs):
output = self.forward(x)
self.backward(x, y, output)
def predict(self, x):
return self.forward(x)
```
其中,`NeuralNetwork` 类包含三个参数:输入层大小、隐藏层大小和输出层大小。在初始化函数中,我们使用正态分布随机初始化隐藏层和输出层的权重和偏差。`sigmoid` 函数用于激活隐藏层和输出层,`sigmoid_derivative` 函数用于计算激活函数的导数。
`forward` 函数将输入 `x` 传递到网络中,并返回输出层的预测值。`backward` 函数用于计算误差和更新权重和偏差,从而使网络的预测更接近目标输出 `y`。`train` 函数用于训练网络,使用反向传播算法来更新权重和偏差。`predict` 函数用于对新数据进行预测。
使用时,可以将数据集 `X` 和目标输出 `y` 传递给 `train` 函数,以训练网络。例如:
```python
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([[0], [1], [1], [0]])
nn = NeuralNetwork(2, 3, 1)
nn.train(X, y, epochs=10000)
print(nn.predict(X))
```
这里使用了一个简单的 XOR 数据集来训练网络。训练完成后,可以使用 `predict` 函数来对新数据进行预测。
阅读全文