easy_rec/python/model/dlrm.py (58 lines of code) (raw):

# -*- encoding:utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import logging import tensorflow as tf from easy_rec.python.layers import dnn from easy_rec.python.model.rank_model import RankModel from easy_rec.python.protos.dlrm_pb2 import DLRM as DLRMConfig # NOQA if tf.__version__ >= '2.0': tf = tf.compat.v1 class DLRM(RankModel): """Implements Deep Learning Recommendation Model for Personalization and Recommendation Systems(FaceBook).""" def __init__(self, model_config, feature_configs, features, labels=None, is_training=False): super(DLRM, self).__init__(model_config, feature_configs, features, labels, is_training) assert model_config.WhichOneof('model') == 'dlrm', \ 'invalid model config: %s' % model_config.WhichOneof('model') self._model_config = model_config.dlrm assert isinstance(self._model_config, DLRMConfig) assert self._input_layer.has_group( 'sparse'), 'sparse group is not specified' _, self._sparse_features = self._input_layer(self._feature_dict, 'sparse') assert self._input_layer.has_group('dense'), 'dense group is not specified' self._dense_feature, _ = self._input_layer(self._feature_dict, 'dense') def build_predict_graph(self): bot_dnn = dnn.DNN(self._model_config.bot_dnn, self._l2_reg, 'bot_dnn', self._is_training) dense_fea = bot_dnn(self._dense_feature) logging.info('arch_interaction_op = %s' % self._model_config.arch_interaction_op) if self._model_config.arch_interaction_op == 'cat': all_fea = tf.concat([dense_fea] + self._sparse_features, axis=1) elif self._model_config.arch_interaction_op == 'dot': assert dense_fea.get_shape()[1] == self._sparse_features[0].get_shape()[1], \ 'bot_dnn last hidden[%d] != sparse feature embedding_dim[%d]' % ( dense_fea.get_shape()[1], self._sparse_features[0].get_shape()[1]) all_feas = [dense_fea] + self._sparse_features all_feas = [x[:, None, :] for x in all_feas] all_feas = tf.concat(all_feas, axis=1) num_fea = all_feas.get_shape()[1] interaction = tf.einsum('bne,bme->bnm', all_feas, all_feas) offset = 0 if self._model_config.arch_interaction_itself else 1 upper_tri = [] for i in range(num_fea): upper_tri.append(interaction[:, i, (i + offset):num_fea]) upper_tri = tf.concat(upper_tri, axis=1) concat_feas = [upper_tri] + self._sparse_features if self._model_config.arch_with_dense_feature: concat_feas.append(dense_fea) all_fea = tf.concat(concat_feas, axis=1) top_dnn = dnn.DNN(self._model_config.top_dnn, self._l2_reg, 'top_dnn', self._is_training) all_fea = top_dnn(all_fea) logits = tf.layers.dense( all_fea, 1, kernel_regularizer=self._l2_reg, name='output') self._add_to_prediction_dict(logits) return self._prediction_dict