easy_rec/python/model/collaborative_metric_learning.py (156 lines of code) (raw):
import tensorflow as tf
from easy_rec.python.core.metrics import metric_learning_average_precision_at_k
from easy_rec.python.core.metrics import metric_learning_recall_at_k
from easy_rec.python.layers import dnn
from easy_rec.python.layers.common_layers import highway
from easy_rec.python.loss.circle_loss import circle_loss
from easy_rec.python.loss.multi_similarity import ms_loss
from easy_rec.python.model.easy_rec_model import EasyRecModel
from easy_rec.python.protos.loss_pb2 import LossType
from easy_rec.python.utils.activation import gelu
from easy_rec.python.utils.proto_util import copy_obj
from easy_rec.python.protos.collaborative_metric_learning_pb2 import CoMetricLearningI2I as MetricLearningI2IConfig # NOQA
# Route TF2 installs through the TF1-compatible API this module is written
# against. Compare the numeric major version: a plain string comparison
# misorders multi-digit majors (e.g. '10.0' >= '2.0' is False).
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
class CoMetricLearningI2I(EasyRecModel):
  """Collaborative metric-learning item-to-item (I2I) embedding model.

  The embedding tower is either the configured backbone or, when there is no
  backbone, an optional ``input`` feature group concatenated with
  batch-normalized, highway-transformed feature groups and fed through a DNN.
  Training uses circle loss or multi-similarity loss, selected by
  ``model_config.loss_type``.
  """

  def __init__(
      self,
      model_config,  # pipeline.model_config
      feature_configs,  # pipeline.feature_configs
      features,  # same as model_fn input
      labels=None,
      is_training=False):
    """Validate the config and prepare input features and labels.

    Args:
      model_config: pipeline.model_config. Its ``model`` oneof must be
        ``metric_learning``, and the loss set in the ``metric_learning``
        ``loss`` oneof must match ``loss_type``.
      feature_configs: pipeline.feature_configs.
      features: feature dict, same as the model_fn input.
      labels: optional label dict. When the config sets ``session_id``, that
        entry is popped from this dict (the dict is mutated). The first
        remaining value is used as the metric-learning label.
      is_training: whether the graph is built for training.

    Raises:
      AssertionError: if the model/loss configuration is inconsistent or the
        label dict is empty after removing the session id.
      ValueError: if the loss type is neither CIRCLE_LOSS nor
        MULTI_SIMILARITY_LOSS.
    """
    super(CoMetricLearningI2I, self).__init__(model_config, feature_configs,
                                              features, labels, is_training)
    model = self._model_config.WhichOneof('model')
    assert model == 'metric_learning', 'invalid model config: %s' % model

    self._loss_type = self._model_config.loss_type
    loss_type_name = LossType.Name(self._loss_type).lower()

    # From here on self._model_config is the metric_learning sub-config.
    self._model_config = self._model_config.metric_learning
    assert isinstance(self._model_config, MetricLearningI2IConfig)

    # WhichOneof returns None when no loss sub-config is set; guard it so the
    # assert below reports the mismatch instead of failing on None.lower().
    model_loss = self._model_config.WhichOneof('loss')
    if model_loss is not None:
      model_loss = model_loss.lower()
    assert model_loss == loss_type_name, 'invalid loss type: %s' % model_loss

    if self._loss_type == LossType.CIRCLE_LOSS:
      self.loss = self._model_config.circle_loss
    elif self._loss_type == LossType.MULTI_SIMILARITY_LOSS:
      self.loss = self._model_config.multi_similarity_loss
    else:
      raise ValueError('unsupported loss type: %s' %
                       LossType.Name(self._loss_type))

    if not self.has_backbone:
      # Embed each highway feature group now, keyed by its input group name;
      # the batch-norm + highway transform is applied at graph-build time.
      self._highway_features = {}
      self._highway_num = len(self._model_config.highway)
      for highway_cfg in self._model_config.highway:
        highway_feature, _ = self._input_layer(self._feature_dict,
                                               highway_cfg.input)
        self._highway_features[highway_cfg.input] = highway_feature

      self.input_features = []
      if self._model_config.HasField('input'):
        input_feature, _ = self._input_layer(self._feature_dict,
                                             self._model_config.input)
        self.input_features.append(input_feature)

      # Private copy of the DNN config; graph building derives its own
      # working copy from it, so this one is never modified.
      self.dnn = copy_obj(self._model_config.dnn)

    # Defined unconditionally so instances built without labels (e.g. for
    # prediction) still expose these attributes.
    self.session_ids = None
    self.labels = None
    if self._labels is not None:
      if self._model_config.HasField('session_id'):
        # Remove the session-id entry so that the first remaining value
        # below is taken as the label.
        self.session_ids = self._labels.pop(self._model_config.session_id)
      assert len(self._labels) > 0
      self.labels = list(self._labels.values())[0]

    if self._model_config.HasField('sample_id'):
      self.sample_id = self._model_config.sample_id
    else:
      self.sample_id = None

  def _build_dnn_tower(self):
    """Build the non-backbone tower: features -> DNN -> linear output layer.

    Works on local copies of the feature list and DNN config, so calling it
    again neither duplicates highway features nor drops another hidden layer.

    Returns:
      The tower embedding tensor.

    Raises:
      ValueError: if there are no input features or the DNN config has no
        hidden units.
    """
    features = list(self.input_features)
    for _id, highway_cfg in enumerate(self._model_config.highway):
      highway_fea = tf.layers.batch_normalization(
          self._highway_features[highway_cfg.input],
          training=self._is_training,
          trainable=True,
          name='highway_%s_bn' % highway_cfg.input)
      highway_fea = highway(
          highway_fea,
          highway_cfg.emb_size,
          activation=gelu,
          scope='highway_%s' % _id)
      tf.logging.info('highway_fea: %s', highway_fea)
      features.append(highway_fea)
    if not features:
      raise ValueError('metric_learning needs an `input` feature group or at '
                       'least one `highway` config')
    feature = tf.concat(features, axis=1)

    dnn_config = copy_obj(self.dnn)
    num_dnn_layer = len(dnn_config.hidden_units)
    if num_dnn_layer == 0:
      raise ValueError('metric_learning.dnn.hidden_units must not be empty')
    # The last hidden size is split off the DNN and built as a separate
    # tf.layers.dense with no activation argument. It is named as if it were
    # the DNN's final layer -- presumably to keep variable names stable.
    last_hidden = dnn_config.hidden_units.pop()
    dnn_net = dnn.DNN(dnn_config, self._l2_reg, 'dnn', self._is_training)
    net_output = dnn_net(feature)
    return tf.layers.dense(
        inputs=net_output,
        units=last_hidden,
        kernel_regularizer=self._l2_reg,
        name='dnn/dnn_%d' % (num_dnn_layer - 1))

  def build_predict_graph(self):
    """Build the embedding tower and fill the prediction dict.

    Prediction entries:
      float_emb: the tower embedding tensor.
      embedding: float_emb converted to strings and comma-joined over the
        last axis.
      norm_emb / norm_embedding: L2-normalized (last axis) variants, only
        when ``output_l2_normalized_emb`` is set.
      sample_id: identity of the ``sample_id`` feature, only when it is
        configured and present in the feature dict.

    Returns:
      The prediction dict.
    """
    if self.has_backbone:
      tower_emb = self.backbone
    else:
      tower_emb = self._build_dnn_tower()

    if self._model_config.output_l2_normalized_emb:
      norm_emb = tf.nn.l2_normalize(tower_emb, axis=-1)
      self._prediction_dict['norm_emb'] = norm_emb
      self._prediction_dict['norm_embedding'] = tf.reduce_join(
          tf.as_string(norm_emb), axis=-1, separator=',')

    self._prediction_dict['float_emb'] = tower_emb
    self._prediction_dict['embedding'] = tf.reduce_join(
        tf.as_string(tower_emb), axis=-1, separator=',')
    if self.sample_id is not None and self.sample_id in self._feature_dict:
      self._prediction_dict['sample_id'] = tf.identity(
          self._feature_dict[self.sample_id])
    return self._prediction_dict

  def build_loss_graph(self):
    """Add the configured metric-learning loss to the loss dict.

    Uses ``norm_emb`` when ``output_l2_normalized_emb`` is set, otherwise the
    raw ``float_emb``, and passes ``embed_normed`` accordingly. Requires
    build_predict_graph to have run and labels to have been given to the
    constructor.

    Returns:
      The loss dict, with key 'circle_loss' or 'ms_loss'.

    Raises:
      ValueError: if no labels were provided, or for an unsupported loss
        type (normally already rejected by the constructor).
    """
    if self.labels is None:
      raise ValueError('build_loss_graph requires labels')
    emb = self._prediction_dict['float_emb']
    emb_normed = self._model_config.output_l2_normalized_emb
    norm_emb = self._prediction_dict['norm_emb'] if emb_normed else emb
    if self._loss_type == LossType.CIRCLE_LOSS:
      self._loss_dict['circle_loss'] = circle_loss(
          norm_emb,
          self.labels,
          self.session_ids,
          self.loss.margin,
          self.loss.gamma,
          embed_normed=emb_normed)
    elif self._loss_type == LossType.MULTI_SIMILARITY_LOSS:
      self._loss_dict['ms_loss'] = ms_loss(
          norm_emb,
          self.labels,
          self.session_ids,
          self.loss.alpha,
          self.loss.beta,
          self.loss.lamb,
          self.loss.eps,
          embed_normed=emb_normed)
    else:
      raise ValueError('invalid loss type: %s' % LossType.Name(self._loss_type))
    return self._loss_dict

  def get_outputs(self):
    """Return the names of prediction-dict entries to export.

    Always includes 'embedding' and 'float_emb'. Adds 'sample_id' when it was
    produced, and 'norm_embedding'/'norm_emb' when
    ``output_l2_normalized_emb`` is set.
    """
    outputs = ['embedding', 'float_emb']
    if self.sample_id is not None and 'sample_id' in self._prediction_dict:
      outputs.append('sample_id')
    if self._model_config.output_l2_normalized_emb:
      outputs.append('norm_embedding')
      outputs.append('norm_emb')
    return outputs

  def build_metric_graph(self, eval_config):
    """Build recall@k / average-precision@k metrics on ``float_emb``.

    Args:
      eval_config: eval config whose ``metrics_set`` entries of type
        ``recall_at_topk`` / ``precision_at_topk`` supply the k values.

    Returns:
      Dict of metrics produced by metric_learning_recall_at_k and
      metric_learning_average_precision_at_k (empty if none requested).

    Raises:
      ValueError: if metrics are requested but no labels were provided.
    """
    metric_dict = {}
    recall_at_k = []
    precision_at_k = []
    for metric in eval_config.metrics_set:
      metric_type = metric.WhichOneof('metric')
      if metric_type == 'recall_at_topk':
        recall_at_k.append(metric.recall_at_topk.topk)
      elif metric_type == 'precision_at_topk':
        precision_at_k.append(metric.precision_at_topk.topk)

    if (recall_at_k or precision_at_k) and self.labels is None:
      raise ValueError('build_metric_graph requires labels')
    # Metrics use the raw (un-normalized) embedding.
    emb = self._prediction_dict['float_emb']
    if recall_at_k:
      metric_dict.update(
          metric_learning_recall_at_k(recall_at_k, emb, self.labels,
                                      self.session_ids))
    if precision_at_k:
      metric_dict.update(
          metric_learning_average_precision_at_k(precision_at_k, emb,
                                                 self.labels, self.session_ids))
    return metric_dict