# easy_rec/python/model/match_model.py
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os
import tensorflow as tf
from easy_rec.python.builders import loss_builder
from easy_rec.python.model.easy_rec_model import EasyRecModel
from easy_rec.python.protos.loss_pb2 import LossType
from easy_rec.python.protos.simi_pb2 import Similarity

if tf.__version__ >= '2.0':
  tf = tf.compat.v1
losses = tf.losses


class MatchModel(EasyRecModel):
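  """Base class for two-tower match (recall) models such as DSSM.

  Supports point-wise training (classification / l2 loss) as well as
  list-wise training with softmax cross entropy over in-batch and
  sampled negatives, optionally with additional hard negatives.
  """
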
def __init__(self,
model_config,
feature_configs,
features,
labels=None,
is_training=False):
super(MatchModel, self).__init__(model_config, feature_configs, features,
labels, is_training)
self._loss_type = self._model_config.loss_type
self._num_class = self._model_config.num_class
    if self._loss_type == LossType.CLASSIFICATION:
      assert self._num_class == 1, \
          'match model requires num_class == 1 for classification loss'
if self._loss_type in [LossType.CLASSIFICATION, LossType.L2_LOSS]:
self._is_point_wise = True
logging.info('Use point wise dssm.')
else:
self._is_point_wise = False
logging.info('Use list wise dssm.')
cls_mem = self._model_config.WhichOneof('model')
sub_model_config = getattr(self._model_config, cls_mem)
self._item_ids = None
assert sub_model_config is not None, 'sub_model_config undefined: model_cls = %s' % cls_mem
if getattr(sub_model_config, 'item_id', '') != '':
logging.info('item_id feature is: %s' % sub_model_config.item_id)
self._item_ids = features[sub_model_config.item_id]

  def _mask_in_batch(self, logits):
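    """Mask invalid in-batch negatives in the raw similarity matrix.

    logits has shape [batch_size, batch_size + num_sampled_neg]; the
    first batch_size columns are in-batch scores whose diagonal holds
    the positive pairs. If ignore_in_batch_neg_sam is set, all
    off-diagonal in-batch entries are suppressed; otherwise, when
    item_id is configured, only in-batch entries whose item_id collides
    with the row's positive item are suppressed. Suppressed entries are
    set to -1e32 so that they vanish after softmax.
    """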
batch_size = tf.shape(logits)[0]
if getattr(self._model_config, 'ignore_in_batch_neg_sam', False):
in_batch = logits[:, :batch_size] - (
1 - tf.diag(tf.ones([batch_size], dtype=tf.float32))) * 1e32
return tf.concat([in_batch, logits[:, batch_size:]], axis=1)
else:
if self._item_ids is not None:
mask_in_batch_neg = tf.to_float(
tf.equal(self._item_ids[None, :batch_size],
self._item_ids[:batch_size, None])) - tf.diag(
tf.ones([batch_size], dtype=tf.float32))
tf.summary.scalar('in_batch_neg_conflict',
tf.reduce_sum(mask_in_batch_neg))
return tf.concat([
logits[:, :batch_size] - mask_in_batch_neg * 1e32,
logits[:, batch_size:]],
axis=1) # yapf: disable
else:
return logits

  def _list_wise_sim(self, user_emb, item_emb):
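    """Compute user-item similarities for list-wise training.

    The rows of item_emb are the in-batch items followed by sampled
    negatives, with hard negative items appended at the end when
    hard_neg_indices is present; hard_neg_indices[k] = (row, col)
    places the k-th hard negative similarity at that position of the
    per-user hard negative block.
    """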
batch_size = tf.shape(user_emb)[0]
hard_neg_indices = self._feature_dict.get('hard_neg_indices', None)
if hard_neg_indices is not None:
logging.info('With hard negative examples')
noclk_size = tf.shape(hard_neg_indices)[0]
simple_item_emb, hard_neg_item_emb = tf.split(
item_emb, [-1, noclk_size], axis=0)
else:
simple_item_emb = item_emb
    # the train/eval/predict mode is expected to be exported to this
    # environment variable elsewhere in the framework; default to the
    # training behavior when it is not set
    _mode = os.environ.get('tf.estimator.mode', '')
    if _mode == tf.estimator.ModeKeys.PREDICT:
      # at predict time users and items arrive as aligned pairs, so only the
      # element-wise similarity of each pair is needed
      simple_user_item_sim = tf.reduce_sum(
          tf.multiply(user_emb, simple_item_emb), axis=1, keep_dims=True)
    else:
      # during training, score each user against all in-batch items and
      # sampled negatives: [batch_size, batch_size + num_sampled_neg]
      simple_user_item_sim = tf.matmul(user_emb, tf.transpose(simple_item_emb))
if hard_neg_indices is None:
return simple_user_item_sim
else:
user_emb_expand = tf.gather(user_emb, hard_neg_indices[:, 0])
hard_neg_user_item_sim = tf.reduce_sum(
tf.multiply(user_emb_expand, hard_neg_item_emb), axis=1)
max_num_neg = tf.reduce_max(hard_neg_indices[:, 1]) + 1
hard_neg_shape = tf.stack([tf.to_int64(batch_size), max_num_neg])
hard_neg_sim = tf.scatter_nd(hard_neg_indices, hard_neg_user_item_sim,
hard_neg_shape)
hard_neg_mask = tf.scatter_nd(
hard_neg_indices,
tf.ones_like(hard_neg_user_item_sim, dtype=tf.float32),
shape=hard_neg_shape)
      # set unused tail positions to -1e32 so that they become zero after
      # exp(x) in the softmax
hard_neg_user_item_sim = hard_neg_sim - (1 - hard_neg_mask) * 1e32
return tf.concat([simple_user_item_sim, hard_neg_user_item_sim], axis=1)

  def _point_wise_sim(self, user_emb, item_emb):
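    """Inner product between each user and its paired item."""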
user_item_sim = tf.reduce_sum(
tf.multiply(user_emb, item_emb), axis=1, keep_dims=True)
return user_item_sim

  def sim(self, user_emb, item_emb):
    # Name the outputs of the user tower and the item tower, i.e. the inputs
    # of the similarity operation. Explicit names for these nodes are
    # necessary for some online recall systems, such as BasicEngine, to split
    # the prediction graph into different clusters.
user_emb = tf.identity(user_emb, 'user_tower_emb')
item_emb = tf.identity(item_emb, 'item_tower_emb')
if self._is_point_wise:
return self._point_wise_sim(user_emb, item_emb)
else:
return self._list_wise_sim(user_emb, item_emb)

  def norm(self, fea):
    # l2 normalize so that inner product equals cosine similarity
    fea_norm = tf.nn.l2_normalize(fea, axis=-1)
    return fea_norm

  def build_predict_graph(self):
    if not self.has_backbone:
      raise NotImplementedError(
          'method `build_predict_graph` must be implemented when you do not '
          'use a backbone network')
model = self._model_config.WhichOneof('model')
assert model == 'model_params', '`model_params` must be configured'
model_params = self._model_config.model_params
for out in model_params.outputs:
self._outputs.append(out)
output = self.backbone
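    # the backbone emits a list of outputs; pick the user / item tower
    # embeddings by their configured indices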
user_tower_emb = output[model_params.user_tower_idx_in_output]
item_tower_emb = output[model_params.item_tower_idx_in_output]
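    # for cosine similarity, l2-normalize both towers and apply the configured
    # softmax temperature; plain inner product keeps temperature at 1.0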
if model_params.simi_func == Similarity.COSINE:
user_tower_emb = self.norm(user_tower_emb)
item_tower_emb = self.norm(item_tower_emb)
temperature = model_params.temperature
else:
temperature = 1.0
user_item_sim = self.sim(user_tower_emb, item_tower_emb) / temperature
if model_params.scale_simi:
      sim_w = tf.get_variable(
          'sim_w',
          dtype=tf.float32,
          shape=[1],
          initializer=tf.ones_initializer())
      sim_b = tf.get_variable(
          'sim_b',
          dtype=tf.float32,
          shape=[1],
          initializer=tf.zeros_initializer())
y_pred = user_item_sim * tf.abs(sim_w) + sim_b
else:
y_pred = user_item_sim
if self._is_point_wise:
y_pred = tf.reshape(y_pred, [-1])
if self._loss_type == LossType.CLASSIFICATION:
self._prediction_dict['logits'] = y_pred
self._prediction_dict['probs'] = tf.nn.sigmoid(y_pred)
elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
y_pred = self._mask_in_batch(y_pred)
self._prediction_dict['logits'] = y_pred
self._prediction_dict['probs'] = tf.nn.softmax(y_pred)
else:
self._prediction_dict['y'] = y_pred
self._prediction_dict['user_tower_emb'] = user_tower_emb
self._prediction_dict['item_tower_emb'] = item_tower_emb
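    # also expose the embeddings as comma separated strings, which are easier
    # to consume when exporting predictions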
self._prediction_dict['user_emb'] = tf.reduce_join(
tf.as_string(user_tower_emb), axis=-1, separator=',')
self._prediction_dict['item_emb'] = tf.reduce_join(
tf.as_string(item_tower_emb), axis=-1, separator=',')
return self._prediction_dict

  def build_loss_graph(self):
if self._is_point_wise:
return self._build_point_wise_loss_graph()
else:
return self._build_list_wise_loss_graph()

  def _build_list_wise_loss_graph(self):
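    """Build softmax cross entropy loss over in-batch and sampled negatives.

    probs has shape [batch_size, batch_size + num_sampled_neg] and the
    diagonal of its first batch_size columns holds the positive pairs,
    so the loss maximizes the log-probability of the diagonal entries.
    """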
if self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
batch_size = tf.shape(self._prediction_dict['probs'])[0]
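      # gather the probabilities of the positive pairs, which sit on the
      # diagonal of the in-batch block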
indices = tf.range(batch_size)
indices = tf.concat([indices[:, None], indices[:, None]], axis=1)
hit_prob = tf.gather_nd(
self._prediction_dict['probs'][:batch_size, :batch_size], indices)
sample_weights = tf.cast(tf.squeeze(self._sample_weight), tf.float32)
self._loss_dict['cross_entropy_loss'] = -tf.reduce_mean(
tf.log(hit_prob + 1e-12) *
sample_weights) / tf.reduce_mean(sample_weights)
logging.info('softmax cross entropy loss is used')
user_features = self._prediction_dict['user_tower_emb']
pos_item_emb = self._prediction_dict['item_tower_emb'][:batch_size]
pos_simi = tf.reduce_sum(user_features * pos_item_emb, axis=1)
      # regularizer: penalize positive pairs whose similarity is negative
      reg_pos_loss = tf.nn.relu(-pos_simi)
self._loss_dict['reg_pos_loss'] = tf.reduce_mean(
reg_pos_loss * sample_weights) / tf.reduce_mean(sample_weights)
      # the AMM (Adaptive Mimic Mechanism) loss of the DAT (Dual Augmented
      # Two-tower) model, which aligns each augmented vector with the
      # embedding produced by the other tower
if all([
k in self._prediction_dict.keys() for k in
['augmented_p_u', 'augmented_p_i', 'augmented_a_u', 'augmented_a_i']
]):
self._loss_dict[
'amm_loss_u'] = self._model_config.amm_u_weight * tf.reduce_mean(
tf.square(self._prediction_dict['augmented_a_u'] -
self._prediction_dict['augmented_p_i'][:batch_size]) *
sample_weights) / tf.reduce_mean(sample_weights)
self._loss_dict[
'amm_loss_i'] = self._model_config.amm_i_weight * tf.reduce_mean(
tf.square(self._prediction_dict['augmented_a_i'][:batch_size] -
self._prediction_dict['augmented_p_u']) *
sample_weights) / tf.reduce_mean(sample_weights)
else:
raise ValueError('invalid loss type: %s' % str(self._loss_type))
return self._loss_dict

  def _build_point_wise_loss_graph(self):
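    """Build point-wise loss: cross entropy or l2 loss per (user, item) pair."""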
label = list(self._labels.values())[0]
if self._loss_type == LossType.CLASSIFICATION:
pred = self._prediction_dict['logits']
loss_name = 'cross_entropy_loss'
elif self._loss_type == LossType.L2_LOSS:
pred = self._prediction_dict['y']
loss_name = 'l2_loss'
else:
raise ValueError('invalid loss type: %s' % str(self._loss_type))
kwargs = {'loss_name': loss_name}
self._loss_dict[loss_name] = loss_builder.build(
self._loss_type,
label=label,
pred=pred,
loss_weight=self._sample_weight,
**kwargs)
    # build knowledge distillation (kd) loss if configured
kd_loss_dict = loss_builder.build_kd_loss(self.kd, self._prediction_dict,
self._labels, self._feature_dict)
self._loss_dict.update(kd_loss_dict)
return self._loss_dict

  def build_metric_graph(self, eval_config):
if self._is_point_wise:
return self._build_point_wise_metric_graph(eval_config)
else:
return self._build_list_wise_metric_graph(eval_config)

  def _build_list_wise_metric_graph(self, eval_config):
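    """Compute recall@k against three candidate sets.

    recall@k: the positive versus in-batch items plus sampled negatives;
    recall_neg_sam@k: the positive versus sampled negatives only;
    recall_in_batch@k: the positive versus in-batch items only.
    """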
from easy_rec.python.core.easyrec_metrics import metrics_tf
logits = self._prediction_dict['logits']
batch_size = tf.shape(logits)[0]
label = tf.cast(tf.range(batch_size), tf.int64)
indices = tf.range(batch_size)
indices = tf.concat([indices[:, None], indices[:, None]], axis=1)
pos_item_sim = tf.gather_nd(logits[:batch_size, :batch_size], indices)
metric_dict = {}
for metric in eval_config.metrics_set:
if metric.WhichOneof('metric') == 'recall_at_topk':
metric_dict['recall@%d' %
metric.recall_at_topk.topk] = metrics_tf.recall_at_k(
label, logits, metric.recall_at_topk.topk)
logits_v2 = tf.concat([pos_item_sim[:, None], logits[:, batch_size:]],
axis=1)
labels_v2 = tf.zeros_like(logits_v2[:, :1], dtype=tf.int64)
metric_dict['recall_neg_sam@%d' %
metric.recall_at_topk.topk] = metrics_tf.recall_at_k(
labels_v2, logits_v2, metric.recall_at_topk.topk)
metric_dict['recall_in_batch@%d' %
metric.recall_at_topk.topk] = metrics_tf.recall_at_k(
label, logits[:, :batch_size],
metric.recall_at_topk.topk)
else:
raise ValueError('invalid metric type: %s' % str(metric))
return metric_dict

  def _build_point_wise_metric_graph(self, eval_config):
from easy_rec.python.core.easyrec_metrics import metrics_tf
metric_dict = {}
label = list(self._labels.values())[0]
for metric in eval_config.metrics_set:
if metric.WhichOneof('metric') == 'auc':
assert self._loss_type == LossType.CLASSIFICATION
metric_dict['auc'] = metrics_tf.auc(label,
self._prediction_dict['probs'])
elif metric.WhichOneof('metric') == 'mean_absolute_error':
assert self._loss_type == LossType.L2_LOSS
metric_dict['mean_absolute_error'] = metrics_tf.mean_absolute_error(
tf.to_float(label), self._prediction_dict['y'])
else:
raise ValueError('invalid metric type: %s' % str(metric))
return metric_dict

  def get_outputs(self):
    if not self.has_backbone:
      raise NotImplementedError(
          'get_outputs must be implemented when you do not use a backbone '
          'network')
if self._loss_type == LossType.CLASSIFICATION:
return [
'logits', 'probs', 'user_emb', 'item_emb', 'user_tower_emb',
'item_tower_emb'
]
elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
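      # at export time the similarity is computed pair-wise with shape
      # [batch_size, 1]; squeeze it and also expose a sigmoid score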
self._prediction_dict['logits'] = tf.squeeze(
self._prediction_dict['logits'], axis=-1)
self._prediction_dict['probs'] = tf.nn.sigmoid(
self._prediction_dict['logits'])
return [
'logits', 'probs', 'user_emb', 'item_emb', 'user_tower_emb',
'item_tower_emb'
]
elif self._loss_type == LossType.L2_LOSS:
return ['y', 'user_emb', 'item_emb', 'user_tower_emb', 'item_tower_emb']
else:
raise ValueError('invalid loss type: %s' % str(self._loss_type))