Seq2Seq模型代码实现
时间: 2023-10-07 13:04:49 浏览: 96
seq2seq模型
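Seq2Seq模型是指序列到序列(Sequence-to-Sequence)模型,通常用于机器翻译、语音识别、摘要生成、对话生成等任务。其中,编码器将输入序列编码为向量表示,解码器则根据该表示逐步解码出输出序列。下面是一个带有Bahdanau注意力机制的简单Seq2Seq模型代码实现,使用了Python和TensorFlow库。注意:其中的evaluate函数还依赖preprocess_sentence、inp_lang、targ_lang、max_length_inp、max_length_targ等对象,它们需要在数据预处理阶段另行定义: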
```python
import numpy as np
import tensorflow as tf
# Maximum input/output sequence length and vocabulary size.
# NOTE(review): MAX_LENGTH is not referenced anywhere else in the visible
# code; VOCAB_SIZE is passed to the Seq2Seq constructor below.
MAX_LENGTH = 100
VOCAB_SIZE = 10000
# Encoder
class Encoder(tf.keras.Model):
    """Embed a batch of token ids and run the embeddings through one GRU.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each embedding vector.
        enc_units: Number of GRU units (also the hidden-state width).
    """

    def __init__(self, vocab_size, embedding_dim, enc_units):
        super().__init__()
        self.enc_units = enc_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        # The GRU returns both the per-step outputs and its final state.
        self.gru = tf.keras.layers.GRU(
            enc_units,
            return_sequences=True,
            return_state=True,
        )

    def call(self, x, hidden):
        """Encode `x`, starting the GRU from `hidden`.

        Args:
            x: Token ids (presumably shaped (batch, seq_len) -- confirm
                against the caller).
            hidden: Initial GRU state.

        Returns:
            A `(sequence_output, final_state)` pair as produced by the GRU.
        """
        embedded = self.embedding(x)
        sequence_output, final_state = self.gru(embedded, initial_state=hidden)
        return sequence_output, final_state

    def initialize_hidden_state(self, batch_size):
        """Return an all-zero state of shape (batch_size, enc_units)."""
        state_shape = (batch_size, self.enc_units)
        return tf.zeros(state_shape)
# Attention layer
class BahdanauAttention(tf.keras.layers.Layer):
    """Additive attention: score = V(tanh(W1(query) + W2(values))).

    Args:
        units: Width of the two projection layers W1 and W2.
    """

    def __init__(self, units):
        super().__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, query, values):
        """Compute a context vector over `values` for the given `query`.

        Args:
            query: Query tensor; an axis is inserted at position 1 so it
                broadcasts against `values` (presumably (batch, units) --
                confirm against the caller).
            values: Tensor attended over; axis 1 is the axis the softmax
                and the weighted sum run along.

        Returns:
            `(context_vector, attention_weights)`, where the context vector
            is `values` weighted by the attention weights and summed over
            axis 1.
        """
        expanded_query = tf.expand_dims(query, 1)
        projected = self.W1(expanded_query) + self.W2(values)
        score = self.V(tf.nn.tanh(projected))

        # Normalise the scores along axis 1 so the weights sum to one there.
        attention_weights = tf.nn.softmax(score, axis=1)
        context_vector = tf.reduce_sum(attention_weights * values, axis=1)
        return context_vector, attention_weights
# Decoder
class Decoder(tf.keras.Model):
    """Single-step GRU decoder with Bahdanau attention over encoder outputs.

    Args:
        vocab_size: Size of the target vocabulary (width of the logits).
        embedding_dim: Size of each target-token embedding.
        dec_units: Number of GRU units; also used as the attention width.
    """

    def __init__(self, vocab_size, embedding_dim, dec_units):
        super().__init__()
        self.dec_units = dec_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(
            dec_units,
            return_sequences=True,
            return_state=True,
        )
        self.fc = tf.keras.layers.Dense(vocab_size)
        self.attention = BahdanauAttention(dec_units)

    def call(self, x, hidden, enc_output):
        """Decode one step.

        Args:
            x: Current decoder input token ids (callers in this file pass a
                (batch, 1) tensor).
            hidden: Previous decoder state. It is used both as the attention
                query and as the GRU's initial state, so its last dimension
                must equal `dec_units`.
            enc_output: Encoder outputs to attend over.

        Returns:
            `(logits, state, attention_weights)`: vocabulary logits with the
            time axis folded into the batch axis, the new GRU state, and the
            attention weights for this step.
        """
        context_vector, attention_weights = self.attention(hidden, enc_output)

        x = self.embedding(x)
        # Prepend the context vector to the embedded token along the
        # feature axis.
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

        # Carry the recurrent state across decoding steps. Without
        # initial_state the GRU would restart from zeros on every call and
        # `hidden` would only influence the attention query.
        output, state = self.gru(x, initial_state=hidden)

        # Collapse (batch, steps, units) to (batch * steps, units) before
        # projecting to vocabulary logits.
        output = tf.reshape(output, (-1, output.shape[2]))
        logits = self.fc(output)
        return logits, state, attention_weights
# Optimizer and loss object.
# reduction='none' keeps one loss value per target position so that
# loss_function can mask out padding before reducing; from_logits=True means
# the predictions are treated as raw logits rather than probabilities.
optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
def loss_function(real, pred):
    """Return the mean cross-entropy over non-padding target positions.

    Positions where `real` equals 0 are treated as padding: their loss is
    zeroed and they are excluded from the average. Dividing by the number of
    real tokens (instead of by every position) keeps the loss scale
    independent of how much padding a batch happens to contain.

    Args:
        real: Integer target ids, with 0 marking padding.
        pred: Logits for each target position.

    Returns:
        A scalar tensor; 0 when every position is padding.
    """
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)
    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
    # divide_no_nan yields 0 instead of NaN for an all-padding batch.
    return tf.math.divide_no_nan(tf.reduce_sum(loss_), tf.reduce_sum(mask))
# Full encoder-decoder model
class Seq2Seq(tf.keras.Model):
    """Encoder-decoder model that decodes with teacher forcing.

    Args:
        vocab_size: Vocabulary size shared by encoder and decoder.
        embedding_dim: Embedding size shared by encoder and decoder.
        enc_units: Number of encoder GRU units.
        dec_units: Number of decoder GRU units.
        batch_size: Nominal batch size. Stored for callers that read it; the
            forward pass derives the actual batch size from its input.
    """

    def __init__(self, vocab_size, embedding_dim, enc_units, dec_units, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.encoder = Encoder(vocab_size, embedding_dim, enc_units)
        self.decoder = Decoder(vocab_size, embedding_dim, dec_units)

    def call(self, inputs):
        """Run the encoder, then decode one step per decoder-input token.

        Args:
            inputs: A pair `(inp, targ)`. `inp` holds the source token ids.
                `targ` holds the decoder input ids, i.e. the target sequence
                without its last token; its second dimension must be
                statically known.

        Returns:
            Logits stacked along axis 1, one step per column of `targ`, so
            the output lines up with the target sequence shifted left by one.
        """
        inp, targ = inputs

        # Use the real batch size so a smaller final batch still works.
        batch_size = tf.shape(inp)[0]
        enc_hidden = self.encoder.initialize_hidden_state(batch_size)
        enc_output, dec_hidden = self.encoder(inp, enc_hidden)

        predictions = []
        for t in range(targ.shape[1]):
            # Teacher forcing: feed the t-th token of every sequence,
            # shaped (batch, 1).
            dec_input = tf.expand_dims(targ[:, t], 1)
            step_logits, dec_hidden, _ = self.decoder(dec_input, dec_hidden, enc_output)
            predictions.append(step_logits)
        return tf.stack(predictions, axis=1)
# Training: build the model instance used by train_step.
# Positional arguments are vocab_size=VOCAB_SIZE, embedding_dim=256,
# enc_units=1024, dec_units=1024, batch_size=64.
model = Seq2Seq(VOCAB_SIZE, 256, 1024, 1024, 64)
def train_step(inp, targ):
    """Run one optimisation step on a batch and return its loss.

    The model receives `targ` without its last column and its predictions
    are scored against `targ` without its first column.

    Args:
        inp: Batch of source token ids.
        targ: Batch of target token ids.

    Returns:
        The scalar loss computed by `loss_function` for this batch.
    """
    decoder_inputs = targ[:, :-1]
    labels = targ[:, 1:]

    with tf.GradientTape() as tape:
        logits = model([inp, decoder_inputs])
        batch_loss = loss_function(labels, logits)

    variables = model.trainable_variables
    grads = tape.gradient(batch_loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    return batch_loss
# Inference
def evaluate(sentence):
    """Greedily translate `sentence` with the trained `model`.

    Relies on names that are not defined in the visible part of this file
    and must be provided by the surrounding module: `preprocess_sentence`,
    `inp_lang`, `targ_lang`, `max_length_inp` and `max_length_targ`.

    Args:
        sentence: Raw input sentence.

    Returns:
        `(result, sentence, attention_plot)`: the decoded words joined by
        spaces (each word followed by one space), the preprocessed input
        sentence, and a (max_length_targ, max_length_inp) array holding the
        attention weights of each decoding step.

    Raises:
        KeyError: If a word of the preprocessed sentence is missing from
            `inp_lang.word_index`.
    """
    attention_plot = np.zeros((max_length_targ, max_length_inp))

    sentence = preprocess_sentence(sentence)
    token_ids = [inp_lang.word_index[word] for word in sentence.split(' ')]
    inputs = tf.keras.preprocessing.sequence.pad_sequences(
        [token_ids], maxlen=max_length_inp, padding='post')
    inputs = tf.convert_to_tensor(inputs)

    # Use the sub-models that train_step actually optimises.
    encoder = model.encoder
    decoder = model.decoder

    hidden = encoder.initialize_hidden_state(1)
    enc_out, dec_hidden = encoder(inputs, hidden)
    dec_input = tf.expand_dims([targ_lang.word_index['<start>']], 0)

    words = []
    for t in range(max_length_targ):
        predictions, dec_hidden, attention_weights = decoder(dec_input, dec_hidden, enc_out)

        # Flatten this step's weights into one row of the plot.
        attention_plot[t] = tf.reshape(attention_weights, (-1,)).numpy()

        predicted_id = tf.argmax(predictions[0]).numpy()
        word = targ_lang.index_word[predicted_id]
        words.append(word)
        if word == '<end>':
            break

        # Feed the prediction back in as the next decoder input.
        dec_input = tf.expand_dims([predicted_id], 0)

    # Every word, including '<end>', is followed by a single space.
    result = ''.join(word + ' ' for word in words)
    return result, sentence, attention_plot
```
阅读全文