in easy_rec/python/model/mind.py [0:0]
def build_metric_graph(self, eval_config):
  from easy_rec.python.core.easyrec_metrics import metrics_tf as metrics
  # build interest metric
  interest_simi, capsule_simi = self._build_interest_simi()
  metric_dict = {
      'interest_similarity': metrics.mean(interest_simi),
      'capsule_similarity': metrics.mean(capsule_simi)
  }
  if self._is_point_wise:
    metric_dict.update(self._build_point_wise_metric_graph(eval_config))
    return metric_dict
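
  # List-wise evaluation from here on: first gather every distinct top-k
  # value requested in eval_config, then compute recall@k over in-batch and
  # sampled negatives.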
  recall_at_topks = []
  for metric in eval_config.metrics_set:
    if metric.WhichOneof('metric') == 'recall_at_topk':
      assert self._loss_type in [
          LossType.CLASSIFICATION, LossType.SOFTMAX_CROSS_ENTROPY
      ]
      if metric.recall_at_topk.topk not in recall_at_topks:
        recall_at_topks.append(metric.recall_at_topk.topk)

  # compute interest recall
  # [batch_size, num_interests, embed_dim]
  user_interests = self._prediction_dict['user_interests']
  # [?, embed_dim]
  item_tower_emb = self._prediction_dict['item_tower_emb']
  batch_size = tf.shape(user_interests)[0]
  # [?, 2] the first column is the sample_id in the batch,
  # the second column is the neg_id with respect to that sample
  hard_neg_indices = self._feature_dict.get('hard_neg_indices', None)
  if hard_neg_indices is not None:
    logging.info('With hard negative examples')
    noclk_size = tf.shape(hard_neg_indices)[0]
    simple_item_emb, hard_neg_item_emb = tf.split(
        item_tower_emb, [-1, noclk_size], axis=0)
  else:
    simple_item_emb = item_tower_emb
    hard_neg_item_emb = None
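
  # Illustrative layout (assumed, based on the comments above): with
  # batch_size=3, two hard negatives for sample 0 and one for sample 2,
  #   hard_neg_indices = [[0, 0], [0, 1], [2, 0]]
  # and item_tower_emb stacks the in-batch/sampled item rows first, followed
  # by the 3 hard-negative rows that tf.split peels off.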

  # [batch_size, num_interests, num_simple_items]
  simple_item_sim = tf.einsum('bhe,ne->bhn', user_interests, simple_item_emb)
  # [batch_size, num_simple_items]: keep the best-matching interest per item
  simple_item_sim = tf.reduce_max(simple_item_sim, axis=1)
  # the positive item for sample i is column i (the in-batch diagonal)
  simple_lbls = tf.cast(tf.range(batch_size), tf.int64)
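
  # For the sampled-negative variant, move each sample's positive score into
  # column 0 and keep only the sampled-negative columns, so the label is
  # always class 0.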
  pos_indices = tf.range(batch_size)
  pos_indices = tf.concat([pos_indices[:, None], pos_indices[:, None]],
                          axis=1)
  pos_item_sim = tf.gather_nd(simple_item_sim[:batch_size, :batch_size],
                              pos_indices)
  simple_item_sim_v2 = tf.concat(
      [pos_item_sim[:, None], simple_item_sim[:, batch_size:]], axis=1)
  simple_lbls_v2 = tf.zeros_like(simple_item_sim_v2[:, :1], dtype=tf.int64)
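
  # Two interest-level recalls per requested k: 'interests_recall@k' ranks
  # the in-batch positive against all simple items; the 'neg_sam' variant
  # ranks it against the sampled negatives only.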
  for topk in recall_at_topks:
    metric_dict['interests_recall@%d' % topk] = metrics.recall_at_k(
        labels=simple_lbls,
        predictions=simple_item_sim,
        k=topk,
        name='interests_recall_at_%d' % topk)
    metric_dict['interests_neg_sam_recall@%d' % topk] = metrics.recall_at_k(
        labels=simple_lbls_v2,
        predictions=simple_item_sim_v2,
        k=topk,
        name='interests_recall_neg_sam_at_%d' % topk)
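
  # The same construction is repeated on the model logits (presumably the
  # scores after combining the interests, e.g. via label-aware attention),
  # mirroring the interest-level metrics above.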
  logits = self._prediction_dict['logits']
  pos_item_logits = tf.gather_nd(logits[:batch_size, :batch_size],
                                 pos_indices)
  logits_v2 = tf.concat([pos_item_logits[:, None], logits[:, batch_size:]],
                        axis=1)
  labels_v2 = tf.zeros_like(logits_v2[:, :1], dtype=tf.int64)

  for topk in recall_at_topks:
    metric_dict['recall@%d' % topk] = metrics.recall_at_k(
        labels=simple_lbls,
        predictions=logits,
        k=topk,
        name='recall_at_%d' % topk)
    metric_dict['recall_neg_sam@%d' % topk] = metrics.recall_at_k(
        labels=labels_v2,
        predictions=logits_v2,
        k=topk,
        name='recall_neg_sam_at_%d' % topk)
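
    # recall_at_k needs at least k candidate columns; if the batch is
    # smaller than k, pad with -1e32 so the padded columns never reach the
    # top-k.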
    eval_logits = logits[:, :batch_size]
    eval_logits = tf.cond(
        batch_size < topk, lambda: tf.pad(
            eval_logits, [[0, 0], [0, topk - batch_size]],
            mode='CONSTANT',
            constant_values=-1e32,
            name='pad_eval_logits'), lambda: eval_logits)
    metric_dict['recall_in_batch@%d' % topk] = metrics.recall_at_k(
        labels=simple_lbls,
        predictions=eval_logits,
        k=topk,
        name='recall_in_batch_at_%d' % topk)
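
  # Hard-negative accuracy: scatter the per-pair similarities into a dense
  # [batch_size, max_num_neg] grid, mask the empty slots with -1e32, and
  # measure how often the positive outranks every hard negative.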
  if hard_neg_indices is not None:
    hard_neg_user_emb = tf.gather(user_interests, hard_neg_indices[:, 0])
    # [num_hard_neg, num_interests]
    hard_neg_sim = tf.einsum('nhe,ne->nh', hard_neg_user_emb,
                             hard_neg_item_emb)
    # [num_hard_neg]: best-matching interest per hard negative
    hard_neg_sim = tf.reduce_max(hard_neg_sim, axis=1)
    max_num_neg = tf.reduce_max(hard_neg_indices[:, 1]) + 1
    hard_neg_shape = tf.stack([tf.to_int64(batch_size), max_num_neg])
    hard_neg_mask = tf.scatter_nd(
        hard_neg_indices,
        tf.ones_like(hard_neg_sim, dtype=tf.float32),
        shape=hard_neg_shape)
    hard_neg_sim = tf.scatter_nd(hard_neg_indices, hard_neg_sim,
                                 hard_neg_shape)
    hard_neg_sim = hard_neg_sim - (1 - hard_neg_mask) * 1e32
    hard_logits = tf.concat([pos_item_logits[:, None], hard_neg_sim], axis=1)
    hard_lbls = tf.zeros_like(hard_logits[:, :1], dtype=tf.int64)
    metric_dict['hard_neg_acc'] = metrics.accuracy(
        hard_lbls, tf.argmax(hard_logits, axis=1))

  return metric_dict
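
# A minimal sketch of the masking trick above, with assumed toy values (not
# part of the original file). Slots that receive no hard negative are pushed
# to -1e32 so they cannot win the argmax:
#
#   indices = tf.constant([[0, 0], [0, 1], [2, 0]], dtype=tf.int64)
#   sims = tf.constant([0.9, 0.1, 0.7])
#   shape = tf.constant([3, 2], dtype=tf.int64)  # batch_size=3, max_num_neg=2
#   mask = tf.scatter_nd(indices, tf.ones_like(sims), shape)
#   dense = tf.scatter_nd(indices, sims, shape) - (1 - mask) * 1e32
#   # row 1 has no hard negatives -> both of its slots are masked out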