Python Source Code for an LSTM
An LSTM (Long Short-Term Memory network) is a widely used deep learning model for processing sequence data. Below is the Python source code for a simple LSTM implementation:
```python
import numpy as np

class LSTM:
    def __init__(self, input_size, hidden_size, output_size):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        # Gate weights operate on the concatenation [h_{t-1}, x_t]
        self.Wf = np.random.randn(hidden_size, input_size + hidden_size)
        self.Wi = np.random.randn(hidden_size, input_size + hidden_size)
        self.Wc = np.random.randn(hidden_size, input_size + hidden_size)
        self.Wo = np.random.randn(hidden_size, input_size + hidden_size)
        self.bf = np.zeros((hidden_size, 1))
        self.bi = np.zeros((hidden_size, 1))
        self.bc = np.zeros((hidden_size, 1))
        self.bo = np.zeros((hidden_size, 1))
        # Linear projection from hidden state to output
        self.Wy = np.random.randn(output_size, hidden_size)
        self.by = np.zeros((output_size, 1))

    def sigmoid(self, x):
        return 1 / (1 + np.exp(-x))

    def tanh(self, x):
        return np.tanh(x)

    def forward(self, x):
        # x has shape (input_size, T); per-step activations are cached for backward().
        T = x.shape[1]
        self.h = np.zeros((self.hidden_size, T))
        self.c = np.zeros((self.hidden_size, T))
        self.f = np.zeros((self.hidden_size, T))
        self.i = np.zeros((self.hidden_size, T))
        self.o = np.zeros((self.hidden_size, T))
        self.cct = np.zeros((self.hidden_size, T))  # candidate cell state, needed in backward()
        self.y = np.zeros((self.output_size, T))
        for t in range(T):
            xt = x[:, t].reshape(-1, 1)
            h_prev = self.h[:, t - 1].reshape(-1, 1) if t > 0 else np.zeros((self.hidden_size, 1))
            c_prev = self.c[:, t - 1].reshape(-1, 1) if t > 0 else np.zeros((self.hidden_size, 1))
            z = np.vstack((h_prev, xt))                      # concatenated input [h_{t-1}, x_t]
            ft = self.sigmoid(np.dot(self.Wf, z) + self.bf)  # forget gate
            it = self.sigmoid(np.dot(self.Wi, z) + self.bi)  # input gate
            cct = self.tanh(np.dot(self.Wc, z) + self.bc)    # candidate cell state
            ot = self.sigmoid(np.dot(self.Wo, z) + self.bo)  # output gate
            ct = ft * c_prev + it * cct                      # new cell state
            ht = ot * self.tanh(ct)                          # new hidden state
            self.f[:, t] = ft[:, 0]
            self.i[:, t] = it[:, 0]
            self.cct[:, t] = cct[:, 0]
            self.o[:, t] = ot[:, 0]
            self.c[:, t] = ct[:, 0]
            self.h[:, t] = ht[:, 0]
            self.y[:, t] = (np.dot(self.Wy, ht) + self.by)[:, 0]
        return self.y

    def backward(self, x, y_true, learning_rate=0.1):
        # Backpropagation through time with a per-step squared-error loss.
        T = x.shape[1]
        dWf = np.zeros_like(self.Wf)
        dWi = np.zeros_like(self.Wi)
        dWc = np.zeros_like(self.Wc)
        dWo = np.zeros_like(self.Wo)
        dbf = np.zeros_like(self.bf)
        dbi = np.zeros_like(self.bi)
        dbc = np.zeros_like(self.bc)
        dbo = np.zeros_like(self.bo)
        dWy = np.zeros_like(self.Wy)
        dby = np.zeros_like(self.by)
        dh_next = np.zeros((self.hidden_size, 1))
        dc_next = np.zeros((self.hidden_size, 1))
        for t in reversed(range(T)):
            xt = x[:, t].reshape(-1, 1)
            h_prev = self.h[:, t - 1].reshape(-1, 1) if t > 0 else np.zeros((self.hidden_size, 1))
            c_prev = self.c[:, t - 1].reshape(-1, 1) if t > 0 else np.zeros((self.hidden_size, 1))
            z = np.vstack((h_prev, xt))
            ft = self.f[:, t].reshape(-1, 1)
            it = self.i[:, t].reshape(-1, 1)
            cct = self.cct[:, t].reshape(-1, 1)
            ot = self.o[:, t].reshape(-1, 1)
            ct = self.c[:, t].reshape(-1, 1)
            ht = self.h[:, t].reshape(-1, 1)
            yt = y_true[:, t].reshape(-1, 1)
            # Output-layer gradients
            dy = self.y[:, t].reshape(-1, 1) - yt
            dWy += np.dot(dy, ht.T)
            dby += dy
            # Gradients w.r.t. hidden state, cell state and gate pre-activations
            dh = np.dot(self.Wy.T, dy) + dh_next
            do = dh * self.tanh(ct) * ot * (1 - ot)
            dc = dh * ot * (1 - self.tanh(ct) ** 2) + dc_next
            df = dc * c_prev * ft * (1 - ft)
            di = dc * cct * it * (1 - it)
            dcct = dc * it * (1 - cct ** 2)
            dc_next = dc * ft  # gradient flowing back to c_{t-1}
            # Accumulate parameter gradients
            dWf += np.dot(df, z.T)
            dWi += np.dot(di, z.T)
            dWc += np.dot(dcct, z.T)
            dWo += np.dot(do, z.T)
            dbf += df
            dbi += di
            dbc += dcct
            dbo += do
            # Only the h-part of the concatenated gradient flows to the previous time step
            dz = (np.dot(self.Wf.T, df) + np.dot(self.Wi.T, di) +
                  np.dot(self.Wc.T, dcct) + np.dot(self.Wo.T, do))
            dh_next = dz[:self.hidden_size, :]
        # Clip gradients to mitigate exploding gradients
        for dparam in [dWf, dWi, dWc, dWo, dbf, dbi, dbc, dbo]:
            np.clip(dparam, -5, 5, out=dparam)
        # Plain gradient-descent parameter update
        for param, dparam in zip([self.Wf, self.Wi, self.Wc, self.Wo,
                                  self.bf, self.bi, self.bc, self.bo,
                                  self.Wy, self.by],
                                 [dWf, dWi, dWc, dWo,
                                  dbf, dbi, dbc, dbo,
                                  dWy, dby]):
            param -= learning_rate * dparam

    def train(self, X_train, Y_train,
              X_valid=None, Y_valid=None,
              epochs=100, learning_rate=0.1):
        # X_train / Y_train are lists of (size, T) arrays.
        is_valid = X_valid is not None
        for i in range(epochs):
            loss_train = 0
            for x_train, y_train in zip(X_train, Y_train):
                y_pred_train = self.forward(x_train)
                self.backward(x_train, y_train, learning_rate)
                loss_train += ((y_pred_train - y_train) ** 2).mean()
            if is_valid:
                loss_valid = 0
                for x_valid, y_valid in zip(X_valid, Y_valid):
                    y_pred_valid = self.forward(x_valid)
                    loss_valid += ((y_pred_valid - y_valid) ** 2).mean()
                print("Epoch {:3d}: Train Loss {:.4f}, Valid Loss {:.4f}".format(
                    i + 1,
                    loss_train / len(X_train),
                    loss_valid / len(X_valid)))
            else:
                print("Epoch {:3d}: Train Loss {:.4f}".format(
                    i + 1,
                    loss_train / len(X_train)))

if __name__ == '__main__':
    # Toy data: random sequences with input_size 10, output_size 5 and a single time step.
    X_train = [np.random.randn(10).reshape(-1, 1) for _ in range(100)]
    Y_train = [np.random.randn(5).reshape(-1, 1) for _ in range(100)]
    X_valid = [np.random.randn(10).reshape(-1, 1) for _ in range(10)]
    Y_valid = [np.random.randn(5).reshape(-1, 1) for _ in range(10)]

    lstm = LSTM(input_size=10,
                hidden_size=32,
                output_size=5)
    lstm.train(X_train=X_train,
               Y_train=Y_train,
               X_valid=X_valid,
               Y_valid=Y_valid,
               epochs=50,
               learning_rate=0.01)
```
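For reference, the forward pass above follows the standard LSTM update equations, written here with the same symbols used in the code; [h_{t-1}, x_t] denotes the concatenation of the previous hidden state and the current input:

```
f_t  = sigmoid(Wf · [h_{t-1}, x_t] + bf)   # forget gate
i_t  = sigmoid(Wi · [h_{t-1}, x_t] + bi)   # input gate
c~_t = tanh(Wc · [h_{t-1}, x_t] + bc)      # candidate cell state
o_t  = sigmoid(Wo · [h_{t-1}, x_t] + bo)   # output gate
c_t  = f_t * c_{t-1} + i_t * c~_t          # cell state update
h_t  = o_t * tanh(c_t)                     # hidden state
y_t  = Wy · h_t + by                       # output projection
```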
The code above implements a basic LSTM model, including both the forward pass and the backward (backpropagation-through-time) pass. If you want to use an LSTM, you can modify and extend this code to fit your own needs. Keep in mind that this is only a simple illustrative implementation; in practice, more sophisticated architectures and training techniques are usually needed to achieve good performance.
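If you want a quick sanity check before extending the code, one option is to repeatedly fit a single short sequence and watch the training error trend downward. The sketch below assumes it runs in the same file as the LSTM class above; the sequence length, layer sizes, learning rate and epoch count are arbitrary choices, and the exact loss values will vary with the random initialization.

```python
import numpy as np

# Minimal sanity check (illustrative only): overfit one random sequence and
# confirm the mean squared error decreases as training proceeds.
np.random.seed(0)
x = np.random.randn(4, 6)   # input_size=4, T=6 time steps
y = np.random.randn(2, 6)   # output_size=2 target at every step

model = LSTM(input_size=4, hidden_size=16, output_size=2)
for epoch in range(200):
    y_pred = model.forward(x)
    model.backward(x, y, learning_rate=0.01)
    if (epoch + 1) % 50 == 0:
        loss = ((y_pred - y) ** 2).mean()
        print("epoch", epoch + 1, "MSE", round(float(loss), 4))
```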