gistlib
# main.py
import tensorflow as tf


class SelfAttention(tf.keras.layers.Layer):
    """Multi-head self-attention layer built on scaled dot-product attention.

    The same ``inputs`` tensor is projected to queries, keys and values by
    three ``Dense(d_model)`` layers. Each projection is split into
    ``num_heads`` heads of size ``d_model // num_heads``, and every head is
    attended with scaled dot-product attention.

    NOTE: the heads are *not* merged back together and no output projection
    is applied. ``call`` therefore returns per-head tensors of shape
    ``(batch, num_heads, seq_len, depth)``, not ``(batch, seq_len, d_model)``.
    No attention mask is applied.

    Args:
        num_heads: Number of attention heads; must be positive.
        d_model: Size of the projected query/key/value vectors; must be
            divisible by ``num_heads``.
        **kwargs: Standard Keras ``Layer`` keyword arguments (``name``,
            ``dtype``, ``trainable``, ...), forwarded to the base class.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, num_heads, d_model, **kwargs):
        super().__init__(**kwargs)
        # Validate with real exceptions: `assert` is stripped under `python -O`.
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads  # feature size of each head

        self.Wq = tf.keras.layers.Dense(d_model)
        self.Wk = tf.keras.layers.Dense(d_model)
        self.Wv = tf.keras.layers.Dense(d_model)

    def get_config(self):
        """Return the layer config so the layer can be saved and re-created.

        ``from_config`` calls ``SelfAttention(**config)``. This works because
        ``__init__`` forwards the base-layer keys (``name``, ``dtype``, ...)
        through ``**kwargs``.
        """
        config = super().get_config()
        config.update({"num_heads": self.num_heads, "d_model": self.d_model})
        return config

    def split_heads(self, x, batch_size):
        """Split the last axis of ``x`` into ``(num_heads, depth)``.

        ``x`` is reshaped to ``(batch_size, -1, num_heads, depth)``, so every
        non-batch leading axis is flattened into a single sequence axis. It is
        then transposed to ``(batch_size, num_heads, seq_len, depth)`` so that
        attention is computed independently per head.
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs):
        """Apply multi-head self-attention to ``inputs``.

        Args:
            inputs: Tensor whose last axis is projected by the ``Dense``
                layers; presumably ``(batch, seq_len, features)`` (confirm
                against callers).

        Returns:
            A tuple ``(output, attention_weights)``:
            ``output`` has shape ``(batch, num_heads, seq_len, depth)`` and
            ``attention_weights`` has shape
            ``(batch, num_heads, seq_len, seq_len)``.
        """
        q = self.Wq(inputs)
        k = self.Wk(inputs)
        v = self.Wv(inputs)

        # Dynamic batch size so the layer works with an unknown batch dim.
        batch_size = tf.shape(q)[0]
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        output, attention_weights = self.scaled_dot_product(q, k, v)
        return output, attention_weights

    def scaled_dot_product(self, q, k, v):
        """Compute ``softmax(q @ k^T / sqrt(d_k)) @ v``.

        Args:
            q: Queries, shape ``(..., seq_q, d_k)``.
            k: Keys, shape ``(..., seq_k, d_k)``.
            v: Values, shape ``(..., seq_k, d_v)``.

        Returns:
            A tuple ``(output, attention_weights)`` with shapes
            ``(..., seq_q, d_v)`` and ``(..., seq_q, seq_k)``.
        """
        matmul_qk = tf.matmul(q, k, transpose_b=True)
        # Cast the scale to the logits' dtype rather than a fixed float32, so
        # the division also works under mixed precision (float16 compute).
        dk = tf.cast(tf.shape(k)[-1], matmul_qk.dtype)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
        # Normalize over the key axis so each query's weights sum to 1.
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        output = tf.matmul(attention_weights, v)
        return output, attention_weights
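This code snippet provides a simple implementation of a multi-head self-attention layer (scaled dot-product attention) in Python using TensorFlow. Note that the layer returns per-head outputs: it does not merge the heads back together or apply an output projection.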
gistlib by LogSnag