easy_rec/python/layers/keras/transformer.py

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging

import numpy as np
import tensorflow as tf
from tensorflow.python.keras.layers import Dense
from tensorflow.python.keras.layers import Dropout
from tensorflow.python.keras.layers import Embedding
from tensorflow.python.keras.layers import Layer

from easy_rec.python.layers.keras import MultiHeadAttention
from easy_rec.python.layers.keras.layer_norm import LayerNormalization
from easy_rec.python.layers.utils import Parameter
from easy_rec.python.protos import seq_encoder_pb2


class TransformerBlock(Layer):
  """A transformer block combines multi-head attention and feed-forward
  networks with layer normalization and dropout.

  Purpose: Combines attention and feed-forward layers with residual connections
    and normalization.
  Components: Multi-head attention, feed-forward network, dropout, and layer
    normalization.
  Output: Enhanced representation after applying attention and feed-forward
    layers.
  """

  def __init__(self, params, name='transformer_block', reuse=None, **kwargs):
    super(TransformerBlock, self).__init__(name=name, **kwargs)
    d_model = params.hidden_size
    num_heads = params.num_attention_heads
    mha_cfg = seq_encoder_pb2.MultiHeadAttention()
    mha_cfg.num_heads = num_heads
    mha_cfg.key_dim = d_model // num_heads
    mha_cfg.dropout = params.get_or_default('attention_probs_dropout_prob',
                                            0.0)
    mha_cfg.return_attention_scores = False
    args = Parameter.make_from_pb(mha_cfg)
    self.mha = MultiHeadAttention(args, 'multi_head_attn')

    dropout_rate = params.get_or_default('hidden_dropout_prob', 0.1)
    ffn_units = params.get_or_default('intermediate_size', d_model)
    ffn_act = params.get_or_default('hidden_act', 'relu')
    self.ffn_dense1 = Dense(ffn_units, activation=ffn_act)
    self.ffn_dense2 = Dense(d_model)

    if tf.__version__ >= '2.0':
      self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
      self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    else:
      self.layer_norm1 = LayerNormalization(epsilon=1e-6)
      self.layer_norm2 = LayerNormalization(epsilon=1e-6)
    self.dropout1 = Dropout(dropout_rate)
    self.dropout2 = Dropout(dropout_rate)

  def call(self, inputs, training=None, **kwargs):
    x, mask = inputs
    attn_output = self.mha([x, x, x], mask=mask, training=training)
    attn_output = self.dropout1(attn_output, training=training)
    out1 = self.layer_norm1(x + attn_output)

    ffn_mid = self.ffn_dense1(out1)
    ffn_output = self.ffn_dense2(ffn_mid)
    ffn_output = self.dropout2(ffn_output, training=training)
    out2 = self.layer_norm2(out1 + ffn_output)
    return out2


# Positional Encoding, https://www.tensorflow.org/text/tutorials/transformer
def positional_encoding(length, depth):
  depth = depth / 2

  positions = np.arange(length)[:, np.newaxis]  # (seq, 1)
  depths = np.arange(depth)[np.newaxis, :] / depth  # (1, depth)

  angle_rates = 1 / (10000**depths)  # (1, depth)
  angle_rads = positions * angle_rates  # (pos, depth)

  pos_encoding = np.concatenate(
      [np.sin(angle_rads), np.cos(angle_rads)], axis=-1)

  return tf.cast(pos_encoding, dtype=tf.float32)
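
# In closed form, positional_encoding(length, depth) returns a (length, depth)
# table where, for position p and channel i in [0, depth / 2):
#   pos_encoding[p, i]              = sin(p / 10000**(2 * i / depth))
#   pos_encoding[p, i + depth // 2] = cos(p / 10000**(2 * i / depth))
# The sin and cos halves are concatenated along the last axis (as in the TF
# text tutorial) rather than interleaved per dimension as in
# "Attention Is All You Need"; both layouts carry the same information.
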
class PositionalEmbedding(Layer):

  def __init__(self, vocab_size, d_model, max_position, name='pos_embedding'):
    super(PositionalEmbedding, self).__init__(name=name)
    self.d_model = d_model
    self.embedding = Embedding(vocab_size, d_model)
    self.pos_encoding = positional_encoding(length=max_position, depth=d_model)

  def call(self, x, training=None):
    length = tf.shape(x)[1]
    x = self.embedding(x)
    # This factor sets the relative scale of the embedding and positional_encoding.
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x = x + self.pos_encoding[tf.newaxis, :length, :]
    return x


class TransformerEncoder(Layer):
  """The encoder consists of a stack of encoder layers.

  It converts the input sequence into a set of embeddings enriched with
  positional information.

  Purpose: Encodes the input sequence into a set of embeddings.
  Components: Embedding layer, positional encoding, and a stack of transformer
    blocks.
  Output: Encoded representation of the input sequence.
  """

  def __init__(self, params, name='transformer_encoder', reuse=None, **kwargs):
    super(TransformerEncoder, self).__init__(name=name, **kwargs)
    d_model = params.hidden_size
    dropout_rate = params.get_or_default('hidden_dropout_prob', 0.1)
    max_position = params.get_or_default('max_position_embeddings', 512)
    num_layers = params.get_or_default('num_hidden_layers', 1)
    vocab_size = params.vocab_size
    logging.info('vocab size of TransformerEncoder(%s) is %d', name, vocab_size)
    self.output_all = params.get_or_default('output_all_token_embeddings',
                                            True)
    self.pos_encoding = PositionalEmbedding(vocab_size, d_model, max_position)
    self.dropout = Dropout(dropout_rate)
    self.enc_layers = [
        TransformerBlock(params, 'layer_%d' % i) for i in range(num_layers)
    ]
    self._vocab_size = vocab_size
    self._max_position = max_position

  @property
  def vocab_size(self):
    return self._vocab_size

  @property
  def max_position(self):
    return self._max_position

  def call(self, inputs, training=None, **kwargs):
    x, mask = inputs  # `x` is token-IDs shape: (batch, seq_len)
    x = self.pos_encoding(x)  # Shape `(batch_size, seq_len, d_model)`.
    x = self.dropout(x, training=training)

    for block in self.enc_layers:
      x = block([x, mask], training)

    # x Shape `(batch_size, seq_len, d_model)`.
    return x if self.output_all else x[:, 0, :]
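
# Shape contract for TransformerEncoder.call (inferred from the code above):
# `inputs` is a pair [token_ids, mask], where token_ids is an integer tensor
# of shape (batch, seq_len) and mask is a boolean tensor of the same shape
# with True marking real tokens (TextEncoder below shows how both are built
# from raw text).  The output has shape (batch, seq_len, hidden_size) when
# output_all_token_embeddings is true; otherwise only the first-token
# embedding of shape (batch, hidden_size) is returned.
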
class TextEncoder(Layer):

  def __init__(self, params, name='text_encoder', reuse=None, **kwargs):
    super(TextEncoder, self).__init__(name=name, **kwargs)
    self.separator = params.get_or_default('separator', ' ')
    self.cls_token = '[CLS]' + self.separator
    self.sep_token = self.separator + '[SEP]' + self.separator
    params.transformer.output_all_token_embeddings = False
    trans_params = Parameter.make_from_pb(params.transformer)
    vocab_file = params.get_or_default('vocab_file', None)
    self.vocab = None
    self.default_token_id = params.get_or_default('default_token_id', 0)
    if vocab_file is not None:
      self.vocab = tf.feature_column.categorical_column_with_vocabulary_file(
          'tokens',
          vocabulary_file=vocab_file,
          default_value=self.default_token_id)
      logging.info('vocab file of TextEncoder(%s) is %s', name, vocab_file)
      trans_params.vocab_size = self.vocab.vocabulary_size
    self.encoder = TransformerEncoder(trans_params, name='transformer')

  def call(self, inputs, training=None, **kwargs):
    if type(inputs) not in (tuple, list):
      inputs = [inputs]
    inputs = [tf.squeeze(text) for text in inputs]
    batch_size = tf.shape(inputs[0])
    cls = tf.fill(batch_size, self.cls_token)
    sep = tf.fill(batch_size, self.sep_token)
    sentences = [cls]
    for sentence in inputs:
      sentences.append(sentence)
      sentences.append(sep)
    text = tf.strings.join(sentences)
    tokens = tf.strings.split(text, self.separator)
    if self.vocab is not None:
      features = {'tokens': tokens}
      token_ids = self.vocab._transform_feature(features)
      token_ids = tf.sparse.to_dense(
          token_ids, default_value=self.default_token_id, name='token_ids')
      length = tf.shape(token_ids)[-1]
      token_ids = tf.cond(
          tf.less_equal(length, self.encoder.max_position), lambda: token_ids,
          lambda: tf.slice(token_ids, [0, 0], [-1, self.encoder.max_position]))
      mask = tf.not_equal(token_ids, self.default_token_id, name='mask')
    else:
      tokens = tf.sparse.to_dense(tokens, default_value='')
      length = tf.shape(tokens)[-1]
      tokens = tf.cond(
          tf.less_equal(length, self.encoder.max_position), lambda: tokens,
          lambda: tf.slice(tokens, [0, 0], [-1, self.encoder.max_position]))
      token_ids = tf.string_to_hash_bucket_fast(
          tokens, self.encoder.vocab_size, name='token_ids')
      mask = tf.not_equal(tokens, '', name='mask')
    encoding = self.encoder([token_ids, mask], training=training)
    return encoding
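

# A minimal usage sketch, not part of the original module.  It assumes that
# seq_encoder.proto defines a `Transformer` message carrying the fields read
# above (hidden_size, num_attention_heads, vocab_size, num_hidden_layers,
# intermediate_size, max_position_embeddings, ...); adjust the config
# construction if the message name or fields differ.
if __name__ == '__main__':
  trans_cfg = seq_encoder_pb2.Transformer()  # assumed message name
  trans_cfg.hidden_size = 64
  trans_cfg.num_attention_heads = 4
  trans_cfg.vocab_size = 1000
  trans_cfg.num_hidden_layers = 1
  trans_cfg.intermediate_size = 128
  trans_cfg.max_position_embeddings = 16
  encoder = TransformerEncoder(Parameter.make_from_pb(trans_cfg))

  # (batch=1, seq_len=4) token ids; id 0 is treated as padding here.
  token_ids = tf.constant([[1, 5, 9, 0]], dtype=tf.int32)
  mask = tf.not_equal(token_ids, 0)
  out = encoder([token_ids, mask], training=False)
  # With the default output_all_token_embeddings=True, `out` has shape
  # (1, 4, 64); otherwise only the first-token embedding (1, 64) is returned.
  print(out)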