tf 实现lstm + self-attention
时间: 2023-11-21 12:07:02 浏览: 112
以下是使用TensorFlow实现LSTM+Self-Attention的代码示例:
```python
import tensorflow as tf
class LSTM_Attention(object):
    """LSTM encoder with additive self-attention pooling (TensorFlow 1.x graph mode).

    Builds a static graph: an LSTM runs over a padded batch of sequences,
    an attention layer (tanh(W·h + b)·u, softmax) pools the per-timestep
    outputs into one vector per example, and a dense tanh projection
    produces the prediction. Trained with MSE loss and Adam with
    global-norm gradient clipping.

    Args:
        hidden_size: dimensionality of the input features, the LSTM state,
            and the output projection (all three share this size here).
        attention_size: dimensionality of the attention projection space.
    """

    def __init__(self, hidden_size, attention_size):
        self.hidden_size = hidden_size
        self.attention_size = attention_size

        # inputs: [batch, time, hidden]; targets: one vector per example.
        self.inputs = tf.placeholder(tf.float32, [None, None, self.hidden_size], name='inputs')
        self.targets = tf.placeholder(tf.float32, [None, self.hidden_size], name='targets')
        # seq_len: true (unpadded) length of each sequence in the batch.
        self.seq_len = tf.placeholder(tf.int32, [None], name='seq_len')
        self.learning_rate = tf.placeholder(tf.float32, name='learning_rate')
        self.global_step = tf.Variable(0, trainable=False)

        with tf.variable_scope('lstm'):
            lstm_cell = tf.nn.rnn_cell.LSTMCell(self.hidden_size)
            # outputs: [batch, time, hidden]. dynamic_rnn zeroes the outputs
            # past seq_len, but tanh(0*W + b) below is NOT zero, so padded
            # positions would still receive attention weight without the
            # explicit mask applied before the softmax.
            outputs, _ = tf.nn.dynamic_rnn(
                lstm_cell, self.inputs, sequence_length=self.seq_len, dtype=tf.float32)

        with tf.variable_scope('attention'):
            attention_w = tf.Variable(
                tf.truncated_normal([self.hidden_size, self.attention_size], stddev=0.1),
                name='attention_w')
            attention_b = tf.Variable(
                tf.constant(0.1, shape=[self.attention_size]), name='attention_b')
            u = tf.Variable(
                tf.truncated_normal([self.attention_size], stddev=0.1), name='attention_u')
            # v: [batch, time, attention_size]; vu: [batch, time] raw scores.
            v = tf.tanh(tf.tensordot(outputs, attention_w, axes=1) + attention_b)
            vu = tf.tensordot(v, u, axes=1, name='vu')
            # FIX: mask padded timesteps with a large negative logit so the
            # softmax assigns them ~zero weight. Without this, padding
            # contributes to the attention distribution and corrupts the
            # pooled vector for batches of mixed-length sequences.
            mask = tf.sequence_mask(self.seq_len, maxlen=tf.shape(self.inputs)[1],
                                    dtype=tf.float32)
            vu = vu + (1.0 - mask) * -1e9
            alphas = tf.nn.softmax(vu, name='alphas')
            # Weighted sum over time: [batch, hidden].
            output = tf.reduce_sum(outputs * tf.expand_dims(alphas, -1), 1)

        with tf.variable_scope('output'):
            w = tf.Variable(
                tf.truncated_normal([self.hidden_size, self.hidden_size], stddev=0.1),
                name='w')
            b = tf.Variable(tf.constant(0.1, shape=[self.hidden_size]), name='b')
            self.logits = tf.matmul(output, w) + b
            self.prediction = tf.nn.tanh(self.logits)

        with tf.variable_scope('loss'):
            # Mean squared error against the per-example target vector.
            self.loss = tf.reduce_mean(tf.square(self.targets - self.prediction))
            optimizer = tf.train.AdamOptimizer(self.learning_rate)
            # Clip by global norm (cap 5.0) to stabilize LSTM training.
            gradients, variables = zip(*optimizer.compute_gradients(self.loss))
            gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
            self.train_op = optimizer.apply_gradients(
                zip(gradients, variables), global_step=self.global_step)

    def train(self, sess, inputs, targets, seq_len, learning_rate):
        """Run one optimization step; returns (loss, global_step)."""
        feed_dict = {self.inputs: inputs, self.targets: targets,
                     self.seq_len: seq_len, self.learning_rate: learning_rate}
        _, loss, step = sess.run(
            [self.train_op, self.loss, self.global_step], feed_dict=feed_dict)
        return loss, step

    def predict(self, sess, inputs, seq_len):
        """Return predictions [batch, hidden] for a padded batch."""
        feed_dict = {self.inputs: inputs, self.seq_len: seq_len}
        prediction = sess.run(self.prediction, feed_dict=feed_dict)
        return prediction
```
阅读全文