完善以下代码:function [W1, W2] = BackpropCE(W1, W2, X, D) % 交叉熵损失下标准BP算法(SGD) % 补充完整:注意delta 和 delta1如何变化? % CE损失下标准BP算法 alpha = 0.9; N = 4; for k = 1:N x = X(k, :)'; % x = a column vector d = D(k); v1 = W1*x; y1 = Sigmoid(v1); v = W2.*y1; y = Sigmoid(v); % 填空:delta=? % 填空:delta1=? % 填空:dW1=? W1 = W1 + dW1; dW2 = alpha*delta*y1'; W2 = W2 + dW2; end end
时间: 2023-06-13 08:05:48 浏览: 68
function [W1, W2] = BackpropCE(W1, W2, X, D)
% 交叉熵损失下标准BP算法(SGD)
% 补充完整:注意delta 和 delta1如何变化?
% CE损失下标准BP算法
alpha = 0.9;
N = 4;
for k = 1:N
x = X(k, :)'; % x = a column vector
d = D(k);
v1 = W1*x;
y1 = Sigmoid(v1);
v = W2*y1;
y = Sigmoid(v);
delta = (y - d) .* y .* (1-y);
delta1 = W2' * delta .* y1 .* (1-y1);
dW1 = alpha * delta1 * x';
W1 = W1 - dW1;
dW2 = alpha * delta * y1';
W2 = W2 - dW2;
end
end
相关问题
完善以下程序:function [W1, W2] = BackpropCE(W1, W2, X, D) % 交叉熵损失下标准BP算法(SGD) % 补充完整:注意delta 和 delta1如何变化? % CE损失下标准BP算法 alpha = 0.9; N = 4; for k = 1:N x = X(k, :)'; % x = a column vector d = D(k); v1 = W1*x; y1 = Sigmoid(v1);
v = W2*y1;
y = Sigmoid(v);
% 计算输出层的delta
delta = -(d - y) .* y .* (1 - y);
% 计算隐藏层的delta
delta1 = y1 .* (1 - y1) .* (W2' * delta);
% 更新W2
W2 = W2 - alpha * delta * y1';
% 更新W1
W1 = W1 - alpha * delta1 * x';
end
优化代码 def GetAlgType(self, AlgType): if AlgType == "SGD_SM1": AlgType = self.AlgType.SGD_SM1 elif AlgType == "SGD_SM4": AlgType = self.AlgType.SGD_SM4 elif AlgType == "SGD_DES": AlgType = self.AlgType.SGD_DES elif AlgType == "SGD_2DES": AlgType = self.AlgType.SGD_2DES elif AlgType == "SGD_3DES": AlgType = self.AlgType.SGD_SM4 elif AlgType == "SGD_AES": AlgType = self.AlgType.SGD_AES elif AlgType == "SGD_AES192": AlgType = self.AlgType.SGD_AES192 elif AlgType == "SGD_AES256": AlgType = self.AlgType.SGD_AES256 return AlgType
可以优化代码,将多个if-elif语句改为使用字典来映射AlgType值。这样可以提高代码的可读性和性能。修改后的代码如下:
def GetAlgType(self, AlgType):
alg_mapping = {
"SGD_SM1": self.AlgType.SGD_SM1,
"SGD_SM4": self.AlgType.SGD_SM4,
"SGD_DES": self.AlgType.SGD_DES,
"SGD_2DES": self.AlgType.SGD_2DES,
"SGD_3DES": self.AlgType.SGD_SM4,
"SGD_AES": self.AlgType.SGD_AES,
"SGD_AES192": self.AlgType.SGD_AES192,
"SGD_AES256": self.AlgType.SGD_AES256
}
return alg_mapping.get(AlgType, AlgType)
这样,通过查找字典中的映射关系,可以直接返回对应的AlgType值,如果找不到映射关系则返回原始的AlgType值。