Multilayer Perceptron Implementation in Python
Date: 2023-07-07 15:17:35
Q: How do I implement a multilayer perceptron in Python?
Here is a simple multilayer perceptron (MLP) implemented with the NumPy library:
```python
import numpy as np

class MLP:
    def __init__(self, num_inputs=3, num_hidden=[3, 5], num_outputs=2):
        self.num_inputs = num_inputs
        self.num_hidden = num_hidden
        self.num_outputs = num_outputs
        layers = [self.num_inputs] + self.num_hidden + [self.num_outputs]
        # Initialize one weight matrix per pair of adjacent layers
        weights = []
        for i in range(len(layers) - 1):
            w = np.random.rand(layers[i], layers[i + 1])
            weights.append(w)
        self.weights = weights

    def sigmoid(self, x):
        return 1 / (1 + np.exp(-x))

    def feed_forward(self, x):
        # Propagate the input forward through every layer
        a = x
        for w in self.weights:
            z = np.dot(a, w)
            a = self.sigmoid(z)
        return a

    def backpropagation(self, x, y, learning_rate=0.1):
        # Feed forward, keeping the activations for the backward pass
        a = x
        activations = [a]
        zs = []
        for w in self.weights:
            z = np.dot(a, w)
            zs.append(z)
            a = self.sigmoid(z)
            activations.append(a)
        # Output-layer error: (prediction - target) * sigmoid'(z)
        error = (activations[-1] - y) * activations[-1] * (1 - activations[-1])
        # Backpropagate the error through the hidden layers
        deltas = [error]
        for i in range(len(self.weights) - 1, 0, -1):
            delta = np.dot(deltas[-1], self.weights[i].T) * activations[i] * (1 - activations[i])
            deltas.append(delta)
        deltas.reverse()
        # Gradient-descent update: outer product of activation and delta
        for i in range(len(self.weights)):
            self.weights[i] -= learning_rate * np.dot(
                activations[i].reshape(-1, 1), deltas[i].reshape(1, -1)
            )

    def train(self, X, y, learning_rate=0.1, epochs=100):
        for epoch in range(epochs):
            for i in range(len(X)):
                self.backpropagation(X[i], y[i], learning_rate)

    def predict(self, X):
        y_pred = []
        for x in X:
            y = self.feed_forward(x)
            y_pred.append(y)
        return np.array(y_pred)
```
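The backward pass above multiplies each error term by `activations[i] * (1 - activations[i])`, relying on the identity σ'(x) = σ(x)(1 − σ(x)). A stand-alone sanity check of that identity against a central finite difference (not part of the original code):

```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

# Verify sigmoid'(x) == sigmoid(x) * (1 - sigmoid(x)) at a few points
xs = np.array([-2.0, -0.5, 0.0, 1.5, 3.0])
h = 1e-6
numeric = (sigmoid(xs + h) - sigmoid(xs - h)) / (2 * h)   # central difference
analytic = sigmoid(xs) * (1 - sigmoid(xs))                # identity used in backprop
print(np.allclose(numeric, analytic, atol=1e-8))  # → True
```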
This multilayer perceptron uses the sigmoid activation function and is trained with the backpropagation algorithm. To use it:
1. Create an MLP object, specifying the sizes of the input, hidden, and output layers.
2. Call the train method with the training data and labels to fit the model.
3. Call the predict method with test data to obtain predictions.
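Step 1 fixes the weight shapes: `__init__` concatenates the layer sizes into one list and pairs consecutive entries. A small sketch of that bookkeeping, using the class defaults (`num_inputs=3`, `num_hidden=[3, 5]`, `num_outputs=2`):

```python
# Mirror the layers list built in MLP.__init__ with its default sizes
layers = [3] + [3, 5] + [2]  # [input, hidden..., output] -> [3, 3, 5, 2]
shapes = [(layers[i], layers[i + 1]) for i in range(len(layers) - 1)]
print(shapes)  # → [(3, 3), (3, 5), (5, 2)]
```

Each tuple is the shape of one weight matrix, so a forward pass maps 3 inputs through two hidden layers to 2 outputs.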
For example, to train a multilayer perceptron on the XOR problem (note the labels below form the XOR truth table, not binary addition), you can use the following code:
```python
X = np.array([[0,0], [0,1], [1,0], [1,1]])
y = np.array([[0], [1], [1], [0]])
mlp = MLP(num_inputs=2, num_hidden=[3], num_outputs=1)
mlp.train(X, y, learning_rate=0.1, epochs=1000)
y_pred = mlp.predict(X)
print(y_pred)
```
This prints the predictions. Since the weights are initialized randomly (and the network has no bias terms), exact values vary between runs; a successful run produces something like:
```
[[0.03054169]
[0.96964232]
[0.96964232]
[0.03054169]]
```
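Because the output layer is a sigmoid, `predict` returns values in (0, 1) rather than hard class labels. A common follow-up step (my addition, using the illustrative values above) is to threshold at 0.5:

```python
import numpy as np

# Continuous sigmoid outputs like those printed above (values are illustrative)
y_pred = np.array([[0.03054169], [0.96964232], [0.96964232], [0.03054169]])

# Threshold at 0.5 to recover hard 0/1 labels for the XOR targets
labels = (y_pred > 0.5).astype(int)
print(labels.ravel())  # → [0 1 1 0]
```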