LSTM 类(class)的实现
时间: 2023-11-11 19:03:01 浏览: 48
以下是一个简单的 LSTM 类的实现,包括单个时间步的前向传播和反向传播:
```python
import numpy as np
class LSTM:
    """Single LSTM cell with hand-written NumPy forward and backward passes.

    The cell processes one time step at a time. Inputs are 2-D arrays with
    the batch on axis 0: ``x`` is ``(batch, input_size)`` and the hidden and
    cell states are ``(batch, hidden_size)``.

    Each gate has a weight matrix of shape
    ``(input_size + hidden_size, hidden_size)`` applied to the horizontal
    concatenation ``[x, h_prev]``, and a bias of shape ``(1, hidden_size)``:

    * ``Wf`` / ``bf`` -- forget gate (sigmoid)
    * ``Wi`` / ``bi`` -- input gate (sigmoid)
    * ``Wo`` / ``bo`` -- output gate (sigmoid)
    * ``Wc`` / ``bc`` -- candidate cell state (tanh)
    """

    def __init__(self, input_size, hidden_size):
        """Create a cell with standard-normal weights and zero biases.

        Args:
            input_size: Number of features in each input row.
            hidden_size: Number of units in the hidden and cell state.
        """
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Weights are drawn in the order f, i, o, c so that results under a
        # fixed NumPy seed do not depend on how this constructor is written.
        concat_size = input_size + hidden_size
        self.Wf = np.random.randn(concat_size, hidden_size)
        self.Wi = np.random.randn(concat_size, hidden_size)
        self.Wo = np.random.randn(concat_size, hidden_size)
        self.Wc = np.random.randn(concat_size, hidden_size)

        self.bf = np.zeros((1, hidden_size))
        self.bi = np.zeros((1, hidden_size))
        self.bo = np.zeros((1, hidden_size))
        self.bc = np.zeros((1, hidden_size))

        # Not used by forward()/backward(), which pass the cache explicitly;
        # kept because it is a public attribute of the class.
        self.cache = None

    def sigmoid(self, x):
        """Return the element-wise logistic function ``1 / (1 + exp(-x))``."""
        return 1 / (1 + np.exp(-x))

    def tanh(self, x):
        """Return the element-wise hyperbolic tangent of ``x``."""
        return np.tanh(x)

    def forward(self, x, h_prev, c_prev):
        """Run one time step of the cell.

        Args:
            x: Input of shape ``(batch, input_size)``.
            h_prev: Previous hidden state, ``(batch, hidden_size)``.
            c_prev: Previous cell state, ``(batch, hidden_size)``.

        Returns:
            A tuple ``(h_next, c_next, cache)`` where ``cache`` is
            ``(concat, f, i, o, c_tilde, c_prev, h_next)`` and must be passed
            unchanged to :meth:`backward`.
        """
        concat = np.hstack((x, h_prev))

        f = self.sigmoid(np.dot(concat, self.Wf) + self.bf)
        i = self.sigmoid(np.dot(concat, self.Wi) + self.bi)
        o = self.sigmoid(np.dot(concat, self.Wo) + self.bo)
        c_tilde = self.tanh(np.dot(concat, self.Wc) + self.bc)

        c_next = f * c_prev + i * c_tilde
        h_next = o * self.tanh(c_next)

        cache = (concat, f, i, o, c_tilde, c_prev, h_next)
        return h_next, c_next, cache

    def _affine_backward(self, concat, d_pre, weight):
        """Backpropagate through ``pre = concat @ weight + bias``.

        Returns ``(d_concat, d_weight, d_bias)`` for the upstream
        pre-activation gradient ``d_pre``.
        """
        d_concat = np.dot(d_pre, weight.T)
        d_weight = np.dot(concat.T, d_pre)
        d_bias = np.sum(d_pre, axis=0, keepdims=True)
        return d_concat, d_weight, d_bias

    def backward(self, dh_next, dc_next, cache):
        """Backpropagate one time step.

        Args:
            dh_next: Gradient of the loss w.r.t. ``h_next``,
                ``(batch, hidden_size)``.
            dc_next: Gradient of the loss w.r.t. ``c_next`` that arrives
                from the following time step, ``(batch, hidden_size)``.
                It is not modified.
            cache: The cache tuple returned by :meth:`forward`.

        Returns:
            ``(dx, dh_prev, dc_prev, dWf, dWi, dWo, dWc, dbf, dbi, dbo, dbc)``
            -- gradients w.r.t. the step's inputs, previous states and all
            parameters, each with the same shape as the corresponding array.
        """
        concat, f, i, o, c_tilde, c_prev, _h_next = cache

        # The cache does not store c_next, so rebuild it with exactly the
        # same expression the forward pass used.
        c_next = f * c_prev + i * c_tilde
        tanh_c_next = np.tanh(c_next)

        # Total gradient reaching c_next: the part from the next step plus
        # the part flowing through h_next = o * tanh(c_next). A new array is
        # created so the caller's dc_next is left untouched.
        dc = dc_next + dh_next * o * (1 - tanh_c_next ** 2)

        # Gradients w.r.t. each gate's pre-activation
        # (sigmoid' = s * (1 - s), tanh' = 1 - t ** 2).
        df_pre = dc * c_prev * f * (1 - f)
        di_pre = dc * c_tilde * i * (1 - i)
        do_pre = dh_next * tanh_c_next * o * (1 - o)
        dc_tilde_pre = dc * i * (1 - c_tilde ** 2)

        dconcat_f, dWf, dbf = self._affine_backward(concat, df_pre, self.Wf)
        dconcat_i, dWi, dbi = self._affine_backward(concat, di_pre, self.Wi)
        dconcat_o, dWo, dbo = self._affine_backward(concat, do_pre, self.Wo)
        dconcat_c, dWc, dbc = self._affine_backward(
            concat, dc_tilde_pre, self.Wc
        )

        # concat = [x, h_prev], so split the summed gradient back at the
        # same column boundary.
        dconcat = dconcat_f + dconcat_i + dconcat_o + dconcat_c
        dx = dconcat[:, :self.input_size]
        dh_prev = dconcat[:, self.input_size:]

        # c_next = f * c_prev + ..., hence the forget-gate factor.
        dc_prev = dc * f

        return dx, dh_prev, dc_prev, dWf, dWi, dWo, dWc, dbf, dbi, dbo, dbc
```