easy_rec/python/layers/uniter.py (259 lines of code) (raw):

# -*- encoding:utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import tensorflow as tf from easy_rec.python.layers import dnn from easy_rec.python.layers import multihead_cross_attention from easy_rec.python.utils.activation import get_activation from easy_rec.python.utils.shape_utils import get_shape_list if tf.__version__ >= '2.0': tf = tf.compat.v1 class Uniter(object): """UNITER: UNiversal Image-TExt Representation Learning. See the original paper: https://arxiv.org/abs/1909.11740 """ def __init__(self, model_config, feature_configs, features, uniter_config, input_layer): self._model_config = uniter_config tower_num = 0 self._img_features = None if input_layer.has_group('image'): self._img_features, _ = input_layer(features, 'image') tower_num += 1 self._general_features = None if input_layer.has_group('general'): self._general_features, _ = input_layer(features, 'general') tower_num += 1 self._txt_seq_features = None if input_layer.has_group('text'): self._txt_seq_features, _, _ = input_layer( features, 'text', is_combine=False) tower_num += 1 self._use_token_type = True if tower_num > 1 else False self._other_features = None if input_layer.has_group('other'): # e.g. statistical feature self._other_features, _ = input_layer(features, 'other') tower_num += 1 assert tower_num > 0, 'there must be one of the feature groups: [image, text, general, other]' self._general_feature_num = 0 self._txt_feature_num, self._img_feature_num = 0, 0 general_feature_names = set() img_feature_names, txt_feature_names = set(), set() for fea_group in model_config.feature_groups: if fea_group.group_name == 'general': self._general_feature_num = len(fea_group.feature_names) general_feature_names = set(fea_group.feature_names) assert self._general_feature_num == len(general_feature_names), ( 'there are duplicate features in `general` feature group') elif fea_group.group_name == 'image': self._img_feature_num = len(fea_group.feature_names) img_feature_names = set(fea_group.feature_names) assert self._img_feature_num == len(img_feature_names), ( 'there are duplicate features in `image` feature group') elif fea_group.group_name == 'text': self._txt_feature_num = len(fea_group.feature_names) txt_feature_names = set(fea_group.feature_names) assert self._txt_feature_num == len(txt_feature_names), ( 'there are duplicate features in `text` feature group') if self._txt_feature_num > 1 or self._img_feature_num > 1: self._use_token_type = True self._token_type_vocab_size = self._txt_feature_num if self._img_feature_num > 0: self._token_type_vocab_size += 1 if self._general_feature_num > 0: self._token_type_vocab_size += 1 max_seq_len = 0 txt_fea_emb_dim_list = [] general_emb_dim_list = [] img_fea_emb_dim_list = [] for feature_config in feature_configs: fea_name = feature_config.input_names[0] if feature_config.HasField('feature_name'): fea_name = feature_config.feature_name if fea_name in img_feature_names: img_fea_emb_dim_list.append(feature_config.raw_input_dim) if fea_name in general_feature_names: general_emb_dim_list.append(feature_config.embedding_dim) if fea_name in txt_feature_names: txt_fea_emb_dim_list.append(feature_config.embedding_dim) if feature_config.HasField('max_seq_len'): assert feature_config.max_seq_len > 0, ( 'feature config `max_seq_len` must be greater than 0 for feature: ' + fea_name) if feature_config.max_seq_len > max_seq_len: max_seq_len = feature_config.max_seq_len unique_dim_num = len(set(txt_fea_emb_dim_list)) assert unique_dim_num <= 1 and len( txt_fea_emb_dim_list ) == self._txt_feature_num, ( 'Uniter requires that all `text` feature dimensions must be consistent.' ) unique_dim_num = len(set(img_fea_emb_dim_list)) assert unique_dim_num <= 1 and len( img_fea_emb_dim_list ) == self._img_feature_num, ( 'Uniter requires that all `image` feature dimensions must be consistent.' ) unique_dim_num = len(set(general_emb_dim_list)) assert unique_dim_num <= 1 and len( general_emb_dim_list ) == self._general_feature_num, ( 'Uniter requires that all `general` feature dimensions must be consistent.' ) if self._txt_feature_num > 0 and uniter_config.use_position_embeddings: assert uniter_config.max_position_embeddings > 0, ( 'model config `max_position_embeddings` must be greater than 0. ') assert uniter_config.max_position_embeddings >= max_seq_len, ( 'model config `max_position_embeddings` must be greater than or equal to the maximum of all feature config ' '`max_seq_len`, which is %d' % max_seq_len) self._img_emb_size = img_fea_emb_dim_list[0] if img_fea_emb_dim_list else 0 self._txt_emb_size = txt_fea_emb_dim_list[0] if txt_fea_emb_dim_list else 0 self._general_emb_size = general_emb_dim_list[ 0] if general_emb_dim_list else 0 if self._img_features is not None: assert self._img_emb_size > 0, '`image` feature dimensions must be greater than 0, set by `raw_input_dim`' def text_embeddings(self, token_type_id): all_txt_features = [] input_masks = [] hidden_size = self._model_config.hidden_size if self._general_features is not None: general_features = self._general_features if self._general_emb_size != hidden_size: # Run a linear projection of `hidden_size` general_features = tf.reshape( general_features, shape=[-1, self._general_emb_size]) general_features = tf.layers.dense( general_features, hidden_size, name='txt_projection') general_features = tf.reshape( general_features, shape=[-1, self._general_feature_num, hidden_size]) batch_size = tf.shape(general_features)[0] general_features = multihead_cross_attention.embedding_postprocessor( general_features, use_token_type=self._use_token_type, token_type_ids=tf.ones( shape=tf.stack([batch_size, self._general_feature_num]), dtype=tf.int32) * token_type_id, token_type_vocab_size=self._token_type_vocab_size, reuse_token_type=tf.AUTO_REUSE, use_position_embeddings=False, dropout_prob=self._model_config.hidden_dropout_prob) all_txt_features.append(general_features) mask = tf.ones( shape=tf.stack([batch_size, self._general_feature_num]), dtype=tf.int32) input_masks.append(mask) if self._txt_seq_features is not None: def dynamic_mask(x, max_len): ones = tf.ones(shape=tf.stack([x]), dtype=tf.int32) zeros = tf.zeros(shape=tf.stack([max_len - x]), dtype=tf.int32) return tf.concat([ones, zeros], axis=0) token_type_id += len(all_txt_features) for i, (seq_fea, seq_len) in enumerate(self._txt_seq_features): batch_size, max_seq_len, emb_size = get_shape_list(seq_fea, 3) if emb_size != hidden_size: seq_fea = tf.reshape(seq_fea, shape=[-1, emb_size]) seq_fea = tf.layers.dense( seq_fea, hidden_size, name='txt_seq_projection_%d' % i) seq_fea = tf.reshape(seq_fea, shape=[-1, max_seq_len, hidden_size]) seq_fea = multihead_cross_attention.embedding_postprocessor( seq_fea, use_token_type=self._use_token_type, token_type_ids=tf.ones( shape=tf.stack([batch_size, max_seq_len]), dtype=tf.int32) * (i + token_type_id), token_type_vocab_size=self._token_type_vocab_size, reuse_token_type=tf.AUTO_REUSE, use_position_embeddings=self._model_config.use_position_embeddings, max_position_embeddings=self._model_config.max_position_embeddings, position_embedding_name='txt_position_embeddings_%d' % i, dropout_prob=self._model_config.hidden_dropout_prob) all_txt_features.append(seq_fea) input_mask = tf.map_fn( fn=lambda t: dynamic_mask(t, max_seq_len), elems=tf.to_int32(seq_len)) input_masks.append(input_mask) return all_txt_features, input_masks def image_embeddings(self): if self._img_features is None: return None hidden_size = self._model_config.hidden_size image_features = self._img_features if self._img_emb_size != hidden_size: # Run a linear projection of `hidden_size` image_features = tf.reshape( image_features, shape=[-1, self._img_emb_size]) image_features = tf.layers.dense( image_features, hidden_size, name='img_projection') image_features = tf.reshape( image_features, shape=[-1, self._img_feature_num, hidden_size]) batch_size = tf.shape(image_features)[0] img_fea = multihead_cross_attention.embedding_postprocessor( image_features, use_token_type=self._use_token_type, token_type_ids=tf.zeros( shape=tf.stack([batch_size, self._img_feature_num]), dtype=tf.int32), token_type_vocab_size=self._token_type_vocab_size, reuse_token_type=tf.AUTO_REUSE, use_position_embeddings=self._model_config.use_position_embeddings, max_position_embeddings=self._model_config.max_position_embeddings, position_embedding_name='img_position_embeddings', dropout_prob=self._model_config.hidden_dropout_prob) return img_fea def __call__(self, is_training, *args, **kwargs): if not is_training: self._model_config.hidden_dropout_prob = 0.0 self._model_config.attention_probs_dropout_prob = 0.0 sub_modules = [] img_fea = self.image_embeddings() start_token_id = 1 if self._img_feature_num > 0 else 0 txt_features, txt_masks = self.text_embeddings(start_token_id) if img_fea is not None: batch_size = tf.shape(img_fea)[0] elif txt_features: batch_size = tf.shape(txt_features[0])[0] else: batch_size = None hidden_size = self._model_config.hidden_size if batch_size is not None: all_features = [] masks = [] cls_emb = tf.get_variable(name='cls_emb', shape=[1, 1, hidden_size]) cls_emb = tf.tile(cls_emb, [batch_size, 1, 1]) all_features.append(cls_emb) mask = tf.ones(shape=tf.stack([batch_size, 1]), dtype=tf.int32) masks.append(mask) if img_fea is not None: all_features.append(img_fea) mask = tf.ones( shape=tf.stack([batch_size, self._img_feature_num]), dtype=tf.int32) masks.append(mask) if txt_features: all_features.extend(txt_features) masks.extend(txt_masks) all_fea = tf.concat(all_features, axis=1) input_mask = tf.concat(masks, axis=1) attention_mask = multihead_cross_attention.create_attention_mask_from_input_mask( from_tensor=all_fea, to_mask=input_mask) hidden_act = get_activation(self._model_config.hidden_act) attention_fea = multihead_cross_attention.transformer_encoder( all_fea, hidden_size=hidden_size, num_hidden_layers=self._model_config.num_hidden_layers, num_attention_heads=self._model_config.num_attention_heads, attention_mask=attention_mask, intermediate_size=self._model_config.intermediate_size, intermediate_act_fn=hidden_act, hidden_dropout_prob=self._model_config.hidden_dropout_prob, attention_probs_dropout_prob=self._model_config .attention_probs_dropout_prob, initializer_range=self._model_config.initializer_range, name='uniter') # shape: [batch_size, seq_length, hidden_size] print('attention_fea:', attention_fea.shape) mm_fea = attention_fea[:, 0, :] # [CLS] feature sub_modules.append(mm_fea) if self._other_features is not None: if self._model_config.HasField('other_feature_dnn'): l2_reg = kwargs['l2_reg'] if 'l2_reg' in kwargs else 0 other_dnn_layer = dnn.DNN(self._model_config.other_feature_dnn, l2_reg, 'other_dnn', is_training) other_fea = other_dnn_layer(self._other_features) else: other_fea = self._other_features sub_modules.append(other_fea) if len(sub_modules) == 1: return sub_modules[0] output = tf.concat(sub_modules, axis=-1) return output