numpy实现mlp的反向传播,其中损失函数使用交叉熵和L2正则化,权重矩阵为增广矩阵,第1个激活函数选择Relu,第2个激活函数选择Softmax
时间: 2024-05-09 14:19:03 浏览: 193
以下是numpy实现mlp的反向传播代码:
```
import numpy as np
def relu(x):
return np.maximum(x, 0)
def softmax(x):
exp_x = np.exp(x - np.max(x, axis=1, keepdims=True))
return exp_x / np.sum(exp_x, axis=1, keepdims=True)
def cross_entropy(y_pred, y_true):
return -np.sum(y_true * np.log(y_pred)) / len(y_pred)
def l2_regularization(lmbda, w1, w2):
return lmbda / 2 * (np.sum(np.square(w1)) + np.sum(np.square(w2)))
def mlp_backward(X, y, w1, b1, w2, b2, learning_rate, lmbda):
# forward pass
z1 = np.dot(X, w1) + b1
a1 = relu(z1)
z2 = np.dot(a1, w2) + b2
a2 = softmax(z2)
# compute gradients
delta2 = a2 - y
delta1 = np.dot(delta2, w2.T) * (z1 > 0)
dw2 = np.dot(a1.T, delta2)
db2 = np.sum(delta2, axis=0, keepdims=True)
dw1 = np.dot(X.T, delta1)
db1 = np.sum(delta1, axis=0)
# add L2 regularization
dw2 += lmbda * w2
dw1 += lmbda * w1
# update weights and biases
w2 -= learning_rate * dw2
b2 -= learning_rate * db2
w1 -= learning_rate * dw1
b1 -= learning_rate * db1
# compute loss
loss = cross_entropy(a2, y) + l2_regularization(lmbda, w1, w2)
return w1, b1, w2, b2, loss
```
其中,X为输入数据,y为标签,w1和w2为权重矩阵,b1和b2为偏置向量,learning_rate为学习率,lmbda为L2正则化系数。在forward pass中,首先计算第一层的加权和z1和激活函数a1,然后计算第二层的加权和z2和激活函数a2。在backward pass中,首先计算输出层的误差项delta2和隐藏层的误差项delta1,然后根据误差项计算梯度dw2、db2、dw1和db1。最后,加上L2正则化项,更新权重矩阵和偏置向量,并计算损失函数。
阅读全文
相关推荐
![pdf](https://img-home.csdnimg.cn/images/20241231044930.png)
![rar](https://img-home.csdnimg.cn/images/20241231044955.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![rar](https://img-home.csdnimg.cn/images/20241231044955.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![](https://img-home.csdnimg.cn/images/20250102104920.png)
![-](https://img-home.csdnimg.cn/images/20241231044955.png)
![-](https://img-home.csdnimg.cn/images/20241231045053.png)
![-](https://img-home.csdnimg.cn/images/20241231044955.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)