def build_metric_graph()

in easy_rec/python/model/mind.py
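
Note: this excerpt assumes the module-level imports of mind.py are in scope. Under the usual EasyRec layout these would include at least the following (the exact module paths are an assumption, as they are not shown in the excerpt):

  import logging
  import tensorflow as tf
  from easy_rec.python.protos.loss_pb2 import LossType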


  def build_metric_graph(self, eval_config):
    from easy_rec.python.core.easyrec_metrics import metrics_tf as metrics
    # build interest metric
    interest_simi, capsule_simi = self._build_interest_simi()
    metric_dict = {
        'interest_similarity': metrics.mean(interest_simi),
        'capsule_similarity': metrics.mean(capsule_simi)
    }
    if self._is_point_wise:
      metric_dict.update(self._build_point_wise_metric_graph(eval_config))
      return metric_dict

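    # collect the distinct recall@k thresholds requested in eval_config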
    recall_at_topks = []
    for metric in eval_config.metrics_set:
      if metric.WhichOneof('metric') == 'recall_at_topk':
        assert self._loss_type in [
            LossType.CLASSIFICATION, LossType.SOFTMAX_CROSS_ENTROPY
        ]
        if metric.recall_at_topk.topk not in recall_at_topks:
          recall_at_topks.append(metric.recall_at_topk.topk)

    # compute interest recall
    # [batch_size, num_interests, embed_dim]
    user_interests = self._prediction_dict['user_interests']
    # [?, embed_dim]
    item_tower_emb = self._prediction_dict['item_tower_emb']
    batch_size = tf.shape(user_interests)[0]
    # [?, 2]: each row holds (sample_id in batch, neg_id of the hard negative
    # within that sample)
    hard_neg_indices = self._feature_dict.get('hard_neg_indices', None)

    if hard_neg_indices is not None:
      logging.info('With hard negative examples')
      noclk_size = tf.shape(hard_neg_indices)[0]
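      # hard negatives are appended at the tail of item_tower_emb; split them
      # off from the in-batch positives and sampled negatives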
      simple_item_emb, hard_neg_item_emb = tf.split(
          item_tower_emb, [-1, noclk_size], axis=0)
    else:
      simple_item_emb = item_tower_emb
      hard_neg_item_emb = None

    # [batch_size, num_interests, num_simple_items]: every interest scored
    # against every in-batch item and sampled negative
    simple_item_sim = tf.einsum('bhe,ne->bhn', user_interests, simple_item_emb)
    # max over interests -> [batch_size, num_simple_items]
    simple_item_sim = tf.reduce_max(simple_item_sim, axis=1)
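    # in-batch positives lie on the diagonal, so the target class for row i is i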
    simple_lbls = tf.cast(tf.range(batch_size), tf.int64)

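    # build (i, i) index pairs so gather_nd picks the diagonal, i.e. each
    # sample's score against its own positive item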
    pos_indices = tf.range(batch_size)
    pos_indices = tf.concat([pos_indices[:, None], pos_indices[:, None]],
                            axis=1)
    pos_item_sim = tf.gather_nd(simple_item_sim[:batch_size, :batch_size],
                                pos_indices)

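    # "neg_sam" variant: column 0 holds each sample's positive score, the rest
    # are the sampled negatives, so the target class is always 0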
    simple_item_sim_v2 = tf.concat(
        [pos_item_sim[:, None], simple_item_sim[:, batch_size:]], axis=1)
    simple_lbls_v2 = tf.zeros_like(simple_item_sim_v2[:, :1], dtype=tf.int64)

    for topk in recall_at_topks:
      metric_dict['interests_recall@%d' % topk] = metrics.recall_at_k(
          labels=simple_lbls,
          predictions=simple_item_sim,
          k=topk,
          name='interests_recall_at_%d' % topk)
      metric_dict['interests_neg_sam_recall@%d' % topk] = metrics.recall_at_k(
          labels=simple_lbls_v2,
          predictions=simple_item_sim_v2,
          k=topk,
          name='interests_recall_neg_sam_at_%d' % topk)

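    # the same recall metrics, computed on the model's output logits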
    logits = self._prediction_dict['logits']
    pos_item_logits = tf.gather_nd(logits[:batch_size, :batch_size],
                                   pos_indices)
    logits_v2 = tf.concat([pos_item_logits[:, None], logits[:, batch_size:]],
                          axis=1)
    labels_v2 = tf.zeros_like(logits_v2[:, :1], dtype=tf.int64)

    for topk in recall_at_topks:
      metric_dict['recall@%d' % topk] = metrics.recall_at_k(
          labels=simple_lbls,
          predictions=logits,
          k=topk,
          name='recall_at_%d' % topk)
      metric_dict['recall_neg_sam@%d' % topk] = metrics.recall_at_k(
          labels=labels_v2,
          predictions=logits_v2,
          k=topk,
          name='recall_neg_sam_at_%d' % topk)
      eval_logits = logits[:, :batch_size]
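      # recall_at_k requires at least k candidate columns; if the (last) batch
      # is smaller than topk, pad with a large negative constant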
      eval_logits = tf.cond(
          batch_size < topk, lambda: tf.pad(
              eval_logits, [[0, 0], [0, topk - batch_size]],
              mode='CONSTANT',
              constant_values=-1e32,
              name='pad_eval_logits'), lambda: eval_logits)
      metric_dict['recall_in_batch@%d' % topk] = metrics.recall_at_k(
          labels=simple_lbls,
          predictions=eval_logits,
          k=topk,
          name='recall_in_batch_at_%d' % topk)

    # hard negative metrics: the einsum below yields [num_hard_neg, num_interests]
    if hard_neg_indices is not None:
      hard_neg_user_emb = tf.gather(user_interests, hard_neg_indices[:, 0])
      hard_neg_sim = tf.einsum('nhe,ne->nh', hard_neg_user_emb,
                               hard_neg_item_emb)
      hard_neg_sim = tf.reduce_max(hard_neg_sim, axis=1)
      max_num_neg = tf.reduce_max(hard_neg_indices[:, 1]) + 1
      hard_neg_shape = tf.stack([tf.to_int64(batch_size), max_num_neg])
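      # scatter the per-pair scores into a dense [batch_size, max_num_neg]
      # grid; slots with no hard negative are masked down to -1e32 below so
      # they can never outrank the positive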
      hard_neg_mask = tf.scatter_nd(
          hard_neg_indices,
          tf.ones_like(hard_neg_sim, dtype=tf.float32),
          shape=hard_neg_shape)
      hard_neg_sim = tf.scatter_nd(hard_neg_indices, hard_neg_sim,
                                   hard_neg_shape)
      hard_neg_sim = hard_neg_sim - (1 - hard_neg_mask) * 1e32

      hard_logits = tf.concat([pos_item_logits[:, None], hard_neg_sim], axis=1)
      hard_lbls = tf.zeros_like(hard_logits[:, :1], dtype=tf.int64)
      metric_dict['hard_neg_acc'] = metrics.accuracy(
          hard_lbls, tf.argmax(hard_logits, axis=1))

    return metric_dict
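
Each value in the returned metric_dict is a TF1-style (value_tensor, update_op) pair from tf.metrics, so the dict can be handed directly to an Estimator as eval_metric_ops. As a sanity check on the recall_at_k semantics relied on above, here is a minimal standalone sketch (values and variable names are illustrative only):

  import tensorflow as tf

  tf1 = tf.compat.v1
  tf1.disable_eager_execution()

  # two users, four candidate columns; user i's positive sits at column i,
  # mirroring simple_lbls = tf.range(batch_size) above
  sims = tf1.constant([[0.9, 0.1, 0.3, 0.2],
                       [0.2, 0.1, 0.8, 0.4]])
  labels = tf1.cast(tf1.range(2), tf.int64)
  value, update = tf1.metrics.recall_at_k(labels=labels, predictions=sims, k=2)

  with tf1.Session() as sess:
    sess.run(tf1.local_variables_initializer())
    sess.run(update)
    print(sess.run(value))  # 0.5: user 0's positive is in its top-2, user 1's is not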