easy_rec/python/model/multi_task_model.py (268 lines of code) (raw):

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
from collections import OrderedDict

import tensorflow as tf

from easy_rec.python.builders import loss_builder
from easy_rec.python.layers.dnn import DNN
from easy_rec.python.model.rank_model import RankModel
from easy_rec.python.protos import tower_pb2
from easy_rec.python.protos.easy_rec_model_pb2 import EasyRecModel
from easy_rec.python.protos.loss_pb2 import LossType

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


class MultiTaskModel(RankModel):
  """Base class for multi-task ranking models organised as task towers.

  Every task tower is bound to one label (see `_init_towers`), has either a
  single `loss_type` or a list of `losses`, its own `metrics_set`, and may be
  chained to other towers through an optional `relation_dnn`.
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    super(MultiTaskModel, self).__init__(model_config, feature_configs,
                                         features, labels, is_training)
    # Task tower configs; filled in by `_init_towers`.
    self._task_towers = []
    self._task_num = None
    # tower_name -> name of the label used by that tower.
    self._label_name_dict = {}

  def build_predict_graph(self):
    """Build task towers (and optional relation towers) on the backbone.

    Only available when a backbone network is configured; subclasses without
    a backbone must override this method.

    Returns:
      The prediction dict, filled through `_add_to_prediction_dict`.

    Raises:
      NotImplementedError: if no backbone network is configured.
      ValueError: if the backbone yields a list/tuple whose length differs
        from the number of task towers.
    """
    if not self.has_backbone:
      raise NotImplementedError(
          'method `build_predict_graph` must be implemented when backbone network do not exists'
      )
    model = self._model_config.WhichOneof('model')
    assert model == 'model_params', '`model_params` must be configured'
    config = self._model_config.model_params
    for out in config.outputs:
      self._outputs.append(out)

    self._init_towers(config.task_towers)

    backbone = self.backbone
    if type(backbone) in (list, tuple):
      if len(backbone) != len(config.task_towers):
        raise ValueError(
            'The number of backbone outputs and task towers must be equal')
      task_input_list = backbone
    else:
      # A single backbone output is shared by every task tower.
      task_input_list = [backbone] * len(config.task_towers)

    # First pass: optional per-tower DNN on top of the tower's input.
    tower_features = {}
    for task_tower_cfg, task_input in zip(config.task_towers, task_input_list):
      tower_name = task_tower_cfg.tower_name
      with tf.name_scope(tower_name):
        if task_tower_cfg.HasField('dnn'):
          tower_dnn = DNN(
              task_tower_cfg.dnn,
              self._l2_reg,
              name=tower_name,
              is_training=self._is_training)
          tower_output = tower_dnn(task_input)
        else:
          tower_output = task_input
      tower_features[tower_name] = tower_output

    # Second pass ("bayes network"): optionally concatenate a tower's features
    # with the relation features of the towers in `relation_tower_names`, run
    # them through `relation_dnn`, then project to `num_class` logits.
    tower_outputs = {}
    relation_features = {}
    for task_tower_cfg in config.task_towers:
      tower_name = task_tower_cfg.tower_name
      with tf.name_scope(tower_name):
        if task_tower_cfg.HasField('relation_dnn'):
          relation_dnn = DNN(
              task_tower_cfg.relation_dnn,
              self._l2_reg,
              name=tower_name + '/relation_dnn',
              is_training=self._is_training)
          tower_inputs = [tower_features[tower_name]]
          # NOTE(review): a referenced relation tower must appear earlier in
          # `task_towers` and itself have a `relation_dnn`, otherwise this
          # lookup raises KeyError.
          for relation_tower_name in task_tower_cfg.relation_tower_names:
            tower_inputs.append(relation_features[relation_tower_name])
          relation_input = tf.concat(
              tower_inputs, axis=-1, name=tower_name + '/relation_input')
          relation_fea = relation_dnn(relation_input)
          relation_features[tower_name] = relation_fea
        else:
          relation_fea = tower_features[tower_name]

        output_logits = tf.layers.dense(
            relation_fea,
            task_tower_cfg.num_class,
            kernel_regularizer=self._l2_reg,
            name=tower_name + '/output')
        tower_outputs[tower_name] = output_logits

    self._add_to_prediction_dict(tower_outputs)
    return self._prediction_dict

  def _init_towers(self, task_tower_configs):
    """Init task towers and bind each tower to a label.

    A tower uses its `label_name` when set; otherwise it takes the label at
    the same position in `self._labels` (kept for backward compatibility
    with list-style label configs).
    """
    self._task_towers = task_tower_configs
    self._task_num = len(task_tower_configs)
    label_names = list(
        self._labels.keys()) if self._labels is not None else []
    for i, task_tower_config in enumerate(task_tower_configs):
      assert isinstance(task_tower_config,
                        (tower_pb2.TaskTower, tower_pb2.BayesTaskTower)), \
          'task_tower_config must be a instance of tower_pb2.TaskTower or tower_pb2.BayesTaskTower'
      tower_name = task_tower_config.tower_name

      # For label backward compatibility with list
      if self._labels is not None:
        if task_tower_config.HasField('label_name'):
          label_name = task_tower_config.label_name
        else:
          # If label name is not specified, task_tower and label will be
          # matched by order.
          label_name = label_names[i]
        logging.info('Task Tower [%s] use label [%s]', tower_name, label_name)
        assert label_name in self._labels, 'label [%s] must exists in labels' % label_name
        self._label_name_dict[tower_name] = label_name

  @staticmethod
  def _tower_loss_types(task_tower_cfg):
    """Return a tower's loss types: those of `losses`, else [`loss_type`]."""
    if len(task_tower_cfg.losses) > 0:
      return [loss.loss_type for loss in task_tower_cfg.losses]
    return [task_tower_cfg.loss_type]

  def _add_to_prediction_dict(self, output):
    """Convert each tower's logits into predictions for each of its losses.

    Args:
      output: dict mapping tower_name to that tower's output logits.
    """
    for task_tower_cfg in self._task_towers:
      tower_name = task_tower_cfg.tower_name
      for loss_type in self._tower_loss_types(task_tower_cfg):
        self._prediction_dict.update(
            self._output_to_prediction_impl(
                output[tower_name],
                loss_type=loss_type,
                num_class=task_tower_cfg.num_class,
                suffix='_%s' % tower_name))

  def build_metric_graph(self, eval_config):
    """Build metric graph for multi task model."""
    for task_tower_cfg in self._task_towers:
      tower_name = task_tower_cfg.tower_name
      # Invariant over the tower's metrics, so compute it once per tower.
      loss_types = set(self._tower_loss_types(task_tower_cfg))
      for metric in task_tower_cfg.metrics_set:
        self._metric_dict.update(
            self._build_metric_impl(
                metric,
                loss_type=loss_types,
                label_name=self._label_name_dict[tower_name],
                num_class=task_tower_cfg.num_class,
                suffix='_%s' % tower_name))
    return self._metric_dict

  def build_loss_weight(self):
    """Compute the static (or random) loss weights of every tower.

    Returns:
      OrderedDict mapping tower_name to a sequence with one weight per entry
      of the tower's `losses` (`loss.weight * tower.weight`), or a single
      `tower.weight` when the tower only sets `loss_type`. Under the `Random`
      strategy each sequence is replaced by the matching slice of a softmax
      over `tf.random_normal` samples.
    """
    loss_weights = OrderedDict()
    num_loss = 0
    for task_tower_cfg in self._task_towers:
      tower_name = task_tower_cfg.tower_name
      losses = task_tower_cfg.losses
      if len(losses) > 0:
        loss_weights[tower_name] = [
            loss.weight * task_tower_cfg.weight for loss in losses
        ]
      else:
        loss_weights[tower_name] = [task_tower_cfg.weight]
      num_loss += len(loss_weights[tower_name])

    strategy = self._base_model_config.loss_weight_strategy
    if strategy == self._base_model_config.Random:
      weights = tf.nn.softmax(tf.random_normal([num_loss]))
      start = 0
      for tower_name in list(loss_weights.keys()):
        n = len(loss_weights[tower_name])
        loss_weights[tower_name] = weights[start:start + n]
        start += n
    return loss_weights

  def get_learnt_loss(self, loss_type, name, value):
    """Weight `value` with a learnt uncertainty scalar.

    Only the `Uncertainty` strategy is supported. With `u` a tf.Variable
    initialised to 0, the result is `exp(-u) * value + 0.5 * u`, with the
    data term additionally halved for L2-style losses.

    Raises:
      ValueError: for any other loss weight strategy.
    """
    strategy = self._base_model_config.loss_weight_strategy
    if strategy == self._base_model_config.Uncertainty:
      uncertainty = tf.Variable(
          0, name='%s_loss_weight' % name, dtype=tf.float32)
      tf.summary.scalar('loss/%s_uncertainty' % name, uncertainty)
      if loss_type in {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
        return 0.5 * tf.exp(-uncertainty) * value + 0.5 * uncertainty
      else:
        return tf.exp(-uncertainty) * value + 0.5 * uncertainty
    else:
      strategy_name = EasyRecModel.LossWeightStrategy.Name(strategy)
      raise ValueError('Unsupported loss weight strategy: ' + strategy_name)

  @staticmethod
  def _task_space_weight(task_tower_cfg, in_task_space):
    """Blend in/out task-space weights using a 0/1 float indicator."""
    return (task_tower_cfg.in_task_space_weight * in_task_space +
            task_tower_cfg.out_task_space_weight * (1 - in_task_space))

  def build_loss_graph(self):
    """Build loss graph for multi task model.

    Each tower's loss is scaled by the sample weight (when
    `use_sample_weight`), by the task-space weighting (when configured), and
    by the tower/loss weight from `build_loss_weight` or a learnt
    uncertainty weight. Knowledge-distillation losses are added last.
    """
    task_loss_weights = self.build_loss_weight()
    for task_tower_cfg in self._task_towers:
      tower_name = task_tower_cfg.tower_name
      loss_weight = 1.0
      if task_tower_cfg.use_sample_weight:
        loss_weight *= self._sample_weight

      if hasattr(task_tower_cfg, 'task_space_indicator_label') and \
          task_tower_cfg.HasField('task_space_indicator_label'):
        in_task_space = tf.to_float(
            self._labels[task_tower_cfg.task_space_indicator_label] > 0)
        loss_weight = loss_weight * self._task_space_weight(
            task_tower_cfg, in_task_space)

      if task_tower_cfg.HasField('task_space_indicator_name') and \
          task_tower_cfg.HasField('task_space_indicator_value'):
        in_task_space = tf.to_float(
            tf.equal(
                self._feature_dict[task_tower_cfg.task_space_indicator_name],
                task_tower_cfg.task_space_indicator_value))
        loss_weight = loss_weight * self._task_space_weight(
            task_tower_cfg, in_task_space)

      # One weight per entry of `losses` (or a single weight), see
      # `build_loss_weight`.
      task_loss_weight = task_loss_weights[tower_name]
      loss_dict = {}
      losses = task_tower_cfg.losses
      if len(losses) == 0:
        loss_dict = self._build_loss_impl(
            task_tower_cfg.loss_type,
            label_name=self._label_name_dict[tower_name],
            loss_weight=loss_weight,
            num_class=task_tower_cfg.num_class,
            suffix='_%s' % tower_name)
        for loss_name, loss_value in list(loss_dict.items()):
          loss_dict[loss_name] = loss_value * task_loss_weight[0]
      else:
        calibrate_loss = []
        # `loss_idx` must be the position in `losses`: `task_loss_weight` is
        # laid out one entry per loss (calibrate losses included).
        for loss_idx, loss in enumerate(losses):
          if loss.loss_type == LossType.ORDER_CALIBRATE_LOSS:
            # Penalise this tower's probability exceeding that of each
            # relation tower: mean(relu(y_t - y_rt)).
            y_t = self._prediction_dict['probs_%s' % tower_name]
            for relation_tower_name in task_tower_cfg.relation_tower_names:
              y_rt = self._prediction_dict['probs_%s' % relation_tower_name]
              cali_loss = tf.reduce_mean(tf.nn.relu(y_t - y_rt))
              calibrate_loss.append(cali_loss * loss.weight)
              logging.info('calibrate loss: %s -> %s', relation_tower_name,
                           tower_name)
            continue
          loss_param = loss.WhichOneof('loss_param')
          if loss_param is not None:
            loss_param = getattr(loss, loss_param)
          loss_ops = self._build_loss_impl(
              loss.loss_type,
              label_name=self._label_name_dict[tower_name],
              loss_weight=loss_weight,
              num_class=task_tower_cfg.num_class,
              suffix='_%s' % tower_name,
              loss_name=loss.loss_name,
              loss_param=loss_param)
          for loss_name, loss_value in loss_ops.items():
            if loss.learn_loss_weight:
              loss_dict[loss_name] = self.get_learnt_loss(
                  loss.loss_type, loss_name, loss_value)
            else:
              loss_dict[loss_name] = loss_value * task_loss_weight[loss_idx]
        if calibrate_loss:
          cali_loss = tf.add_n(calibrate_loss)
          # NOTE(review): this key is not tower-specific, so when several
          # towers use ORDER_CALIBRATE_LOSS only the last one survives the
          # `self._loss_dict.update` below.
          loss_dict['order_calibrate_loss'] = cali_loss
          tf.summary.scalar('loss/order_calibrate_loss', cali_loss)
      self._loss_dict.update(loss_dict)

    kd_loss_dict = loss_builder.build_kd_loss(self.kd, self._prediction_dict,
                                              self._labels, self._feature_dict)
    self._loss_dict.update(kd_loss_dict)
    return self._loss_dict

  def get_outputs(self):
    """Return the output names to export, de-duplicated in first-seen order.

    Includes the configured `outputs` followed by the outputs of every
    tower's loss types; ORDER_CALIBRATE_LOSS entries of `losses` contribute
    no outputs.
    """
    outputs = []
    if self._outputs:
      outputs.extend(self._outputs)
    for task_tower_cfg in self._task_towers:
      tower_name = task_tower_cfg.tower_name
      if len(task_tower_cfg.losses) == 0:
        outputs.extend(
            self._get_outputs_impl(
                task_tower_cfg.loss_type,
                task_tower_cfg.num_class,
                suffix='_%s' % tower_name))
      else:
        for loss in task_tower_cfg.losses:
          if loss.loss_type == LossType.ORDER_CALIBRATE_LOSS:
            continue
          outputs.extend(
              self._get_outputs_impl(
                  loss.loss_type,
                  task_tower_cfg.num_class,
                  suffix='_%s' % tower_name))
    # Order-preserving dedup; `list(set(...))` produced an arbitrary order
    # that could differ between processes (string hash randomization).
    return list(OrderedDict.fromkeys(outputs))