easy_rec/python/model/rocket_launching.py (174 lines of code) (raw):

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Rocket Launching ranking model.

A "booster" DNN and a "light" DNN are trained side by side on the same label.
The light net receives a stop-gradient copy of the (optionally shared) input
and is pulled towards the booster through an L2 hint loss on the logits and,
optionally, through per-hidden-layer feature distillation terms.
"""
import tensorflow as tf

from easy_rec.python.builders import loss_builder
from easy_rec.python.layers import dnn
from easy_rec.python.model.rank_model import RankModel
from easy_rec.python.protos.loss_pb2 import LossType
from easy_rec.python.protos.simi_pb2 import Similarity

from easy_rec.python.protos.rocket_launching_pb2 import RocketLaunching as RocketLaunchingConfig  # NOQA

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


class RocketLaunching(RankModel):
  """Booster/light ranking model with logit and feature distillation.

  Both networks get a classification loss; the light network additionally
  mimics the booster (see ``build_loss_graph``). Gradients from the light
  network never reach the booster or the shared input layers because its
  input and its distillation targets are wrapped in ``tf.stop_gradient``.
  """

  # Magnitude of the weight applied to each per-layer feature distillation
  # term added in ``build_loss_graph``.
  _FEATURE_DISTILLATION_WEIGHT = 0.1

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    super(RocketLaunching, self).__init__(model_config, feature_configs,
                                          features, labels, is_training)
    assert self._model_config.WhichOneof('model') == 'rocket_launching', \
        'invalid model config: %s' % self._model_config.WhichOneof('model')
    self._model_config = self._model_config.rocket_launching
    assert isinstance(self._model_config, RocketLaunchingConfig)
    if self._labels is not None:
      # Only the first label is used by this model.
      self._label_name = list(self._labels.keys())[0]

    self._features, _ = self._input_layer(self._feature_dict, 'all')

  def sim(self, feature_emb1, feature_emb2):
    """Return the inner product of two embeddings, reduced over axis 1.

    The reduced axis is kept (``keepdims=True``), so the result has size 1
    along axis 1.
    """
    return tf.reduce_sum(
        tf.multiply(feature_emb1, feature_emb2), axis=1, keepdims=True)

  def norm(self, fea):
    """Return ``fea`` L2-normalized along axis 1."""
    return tf.nn.l2_normalize(fea, axis=1)

  def feature_based_sim(self, feature_based_distillation, i, j):
    """Compare light hidden layer ``i`` with booster hidden layer ``j``.

    Args:
      feature_based_distillation: a ``Similarity`` enum value selecting the
        comparison. The parameter name is kept for backward compatibility;
        it is NOT the boolean config flag of the same name.
      i: index of the light network hidden layer.
      j: index of the booster network hidden layer.

    Returns:
      For ``Similarity.COSINE``: the mean cosine similarity of the two layer
      outputs (higher means more alike). Otherwise: the Euclidean distance
      ``sqrt(sum((booster - light)**2))`` (lower means more alike).
      The booster side is wrapped in ``tf.stop_gradient`` so this term does
      not back-propagate into the booster network.
    """
    booster_hidden = tf.stop_gradient(
        self.booster_feature['hidden_layer' + str(j)])
    light_hidden = self.light_feature['hidden_layer' + str(i)]
    if feature_based_distillation == Similarity.COSINE:
      return tf.reduce_mean(
          self.sim(self.norm(booster_hidden), self.norm(light_hidden)))
    return tf.sqrt(tf.reduce_sum(tf.square(booster_hidden - light_hidden)))

  def build_predict_graph(self):
    """Build the booster and light networks and their prediction outputs.

    Returns:
      ``self._prediction_dict`` updated with the ``*_booster`` and ``*_light``
      entries produced by ``_output_to_prediction_impl``.
    """
    # With feature distillation on, the DNNs are asked for per-layer outputs;
    # the results are then indexed by 'hidden_layer_end' below and by
    # 'hidden_layer<k>' in feature_based_sim.
    self.hidden_layer_feature_output = self._model_config.feature_based_distillation

    dnn_input = self._features
    if self._model_config.HasField('share_dnn'):
      share_dnn_layer = dnn.DNN(self._model_config.share_dnn, self._l2_reg,
                                'share_dnn', self._is_training)
      dnn_input = share_dnn_layer(self._features)

    booster_dnn_layer = dnn.DNN(self._model_config.booster_dnn, self._l2_reg,
                                'booster_dnn', self._is_training)
    light_dnn_layer = dnn.DNN(self._model_config.light_dnn, self._l2_reg,
                              'light_dnn', self._is_training)

    self.booster_feature = booster_dnn_layer(dnn_input,
                                             self.hidden_layer_feature_output)
    # The light net must not push gradients into the shared/input layers;
    # those are trained only through the booster path.
    self.light_feature = light_dnn_layer(
        tf.stop_gradient(dnn_input), self.hidden_layer_feature_output)

    if self.hidden_layer_feature_output:
      booster_last = self.booster_feature['hidden_layer_end']
      light_last = self.light_feature['hidden_layer_end']
    else:
      booster_last = self.booster_feature
      light_last = self.light_feature

    booster_out = tf.layers.dense(
        booster_last,
        self._num_class,
        kernel_regularizer=self._l2_reg,
        name='booster_output')
    light_out = tf.layers.dense(
        light_last,
        self._num_class,
        kernel_regularizer=self._l2_reg,
        name='light_output')

    for output, suffix in ((booster_out, '_booster'), (light_out, '_light')):
      self._prediction_dict.update(
          self._output_to_prediction_impl(
              output,
              self._loss_type,
              num_class=self._num_class,
              suffix=suffix))

    return self._prediction_dict

  def build_loss_graph(self):
    """Build classification, logit-hint and feature distillation losses.

    Entries added to ``self._loss_dict`` (in this order):
      * classification losses for the booster and the light outputs;
      * 'hint_loss': L2 loss between the light logits and the
        stop-gradient booster logits;
      * 'similarity_<k>' (only when ``feature_based_distillation`` is set):
        one term per light hidden layer that has a booster hidden layer of
        the same width (first match wins). The raw terms are also stored in
        ``self._prediction_dict`` under the same keys.

    Returns:
      ``self._loss_dict``.
    """
    logits_booster = self._prediction_dict['logits_booster']
    logits_light = self._prediction_dict['logits_light']
    self.feature_distillation_function = self._model_config.feature_distillation_function

    similarity_keys = []
    if self._model_config.feature_based_distillation:
      booster_hidden_units = self._model_config.booster_dnn.hidden_units
      light_hidden_units = self._model_config.light_dnn.hidden_units
      for i, light_units in enumerate(light_hidden_units):
        for j, booster_units in enumerate(booster_hidden_units):
          if light_units == booster_units:
            key = 'similarity_' + str(len(similarity_keys))
            # Pass the configured Similarity enum, not the boolean on/off
            # flag: the flag is always truthy here, so the COSINE branch of
            # feature_based_sim could never be selected from the config.
            self._prediction_dict[key] = self.feature_based_sim(
                self.feature_distillation_function, i, j)
            similarity_keys.append(key)
            break

    for suffix in ('_booster', '_light'):
      self._loss_dict.update(
          self._build_loss_impl(
              LossType.CLASSIFICATION,
              label_name=self._label_name,
              loss_weight=self._sample_weight,
              num_class=self._num_class,
              suffix=suffix))

    # The booster logits are a fixed target for the light net.
    booster_logits_no_grad = tf.stop_gradient(logits_booster)
    self._loss_dict['hint_loss'] = loss_builder.build(
        LossType.L2_LOSS,
        label=booster_logits_no_grad,
        pred=logits_light,
        loss_weight=self._sample_weight)

    if similarity_keys:
      # Cosine similarity should be maximized (negative weight), whereas the
      # non-cosine branch of feature_based_sim returns a distance that must be
      # minimized (positive weight). A negative weight on a distance would
      # reward pushing the light features away from the booster's.
      if self.feature_distillation_function == Similarity.COSINE:
        weight = -self._FEATURE_DISTILLATION_WEIGHT
      else:
        weight = self._FEATURE_DISTILLATION_WEIGHT
      for key in similarity_keys:
        self._loss_dict[key] = weight * self._prediction_dict[key]

    return self._loss_dict

  def build_metric_graph(self, eval_config):
    """Build every configured metric for both the light and booster outputs.

    Returns:
      dict of metric ops from ``_build_metric_impl``, light then booster for
      each metric in ``eval_config.metrics_set``.
    """
    metric_dict = {}
    for metric in eval_config.metrics_set:
      for suffix in ('_light', '_booster'):
        metric_dict.update(
            self._build_metric_impl(
                metric,
                loss_type=LossType.CLASSIFICATION,
                label_name=self._label_name,
                num_class=self._num_class,
                suffix=suffix))
    return metric_dict

  def get_outputs(self):
    """Return output names for the light network followed by the booster."""
    outputs = []
    for suffix in ('_light', '_booster'):
      outputs.extend(
          self._get_outputs_impl(
              self._loss_type, self._num_class, suffix=suffix))
    return outputs