Implementing self-attention with TensorFlow
Here is a code example of a self-attention mechanism implemented with TensorFlow:
```
import tensorflow as tf
from tensorflow.keras.layers import Layer
class SelfAttention(Layer):
    def __init__(self, hidden_size, num_heads, **kwargs):
        super(SelfAttention, self).__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.query = tf.keras.layers.Dense(hidden_size, use_bias=False)
        self.key = tf.keras.layers.Dense(hidden_size, use_bias=False)
        self.value = tf.keras.layers.Dense(hidden_size, use_bias=False)
        self.softmax = tf.keras.layers.Softmax(axis=-1)

    def call(self, x):
        batch_size, seq_len = tf.shape(x)[0], tf.shape(x)[1]
        residual = x
        # Project inputs to Q, K, V
        q = self.query(x)  # (batch_size, seq_len, hidden_size)
        k = self.key(x)    # (batch_size, seq_len, hidden_size)
        v = self.value(x)  # (batch_size, seq_len, hidden_size)
        # Reshape Q, K, V for multi-head attention
        q = tf.reshape(q, (batch_size, seq_len, self.num_heads, self.head_size))
        k = tf.reshape(k, (batch_size, seq_len, self.num_heads, self.head_size))
        v = tf.reshape(v, (batch_size, seq_len, self.num_heads, self.head_size))
        # Move the head dimension ahead of the sequence dimension:
        # (batch_size, num_heads, seq_len, head_size)
        q = tf.transpose(q, perm=[0, 2, 1, 3])
        k = tf.transpose(k, perm=[0, 2, 1, 3])
        v = tf.transpose(v, perm=[0, 2, 1, 3])
        # Compute scaled dot-product attention scores
        attention_scores = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(
            tf.cast(self.head_size, tf.float32))  # (batch_size, num_heads, seq_len, seq_len)
        attention_probs = self.softmax(attention_scores)
        # Apply attention weights to V
        attention_output = tf.matmul(attention_probs, v)  # (batch_size, num_heads, seq_len, head_size)
        # Merge the heads back into a single hidden dimension
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
        attention_output = tf.reshape(attention_output, (batch_size, seq_len, self.hidden_size))
        # Add residual connection
        output = attention_output + residual
        return output
```
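A minimal usage sketch for the layer above (the hidden_size=64, num_heads=8, and batch/sequence sizes below are illustrative assumptions, not values from the original answer):
```
# Illustrative values only: hidden_size=64, num_heads=8, batch of 2 sequences of length 10
layer = SelfAttention(hidden_size=64, num_heads=8)
x = tf.random.normal((2, 10, 64))  # (batch_size, seq_len, hidden_size)
y = layer(x)
print(y.shape)  # (2, 10, 64) -- same shape as the input, thanks to the residual connection
```
Note that the input's last dimension must equal hidden_size, since the residual connection adds the input directly to the attention output.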
This self-attention implementation is similar to a typical PyTorch implementation: it contains the same three linear projections for query, key, and value, followed by a Softmax. The main point to watch in the TensorFlow version is that the multi-head dimension handling is spelled out explicitly, so the reshape and transpose operations on the head and sequence dimensions must be ordered carefully.
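As a quick check of those reshape/transpose steps in isolation (the concrete sizes below are illustrative assumptions):
```
batch_size, seq_len, num_heads, head_size = 2, 10, 8, 8

q = tf.random.normal((batch_size, seq_len, num_heads * head_size))
# (batch_size, seq_len, hidden_size) -> (batch_size, seq_len, num_heads, head_size)
q = tf.reshape(q, (batch_size, seq_len, num_heads, head_size))
# (batch_size, seq_len, num_heads, head_size) -> (batch_size, num_heads, seq_len, head_size)
q = tf.transpose(q, perm=[0, 2, 1, 3])
print(q.shape)  # (2, 8, 10, 8)
```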