def build_predict_graph()

in easy_rec/python/model/mind.py [0:0]


  def build_predict_graph(self):
    """Build the MIND prediction graph and fill `self._prediction_dict`.

    Steps:
      1. Combine the history sequence features (mean pooling or concat),
         optionally pass them through `pre_capsule_dnn`, and optionally
         re-weight them with a softmax over the `time_id_fea` sequence
         feature.
      2. Extract interest capsules with `CapsuleLayer`.
      3. Concatenate every capsule with the user tower output and project it
         through `concat_dnn` to get the user interest embeddings.
      4. Build the item tower with `item_dnn`.
      5. Label-guided attention: weight the interests by their (scaled,
         masked) similarity to the positive items and collapse them into a
         single user embedding.
      6. Score the user embedding against the item embeddings and export
         logits/probs (or `y`) plus the embeddings used for serving.

    The DNN configs (`self.concat_dnn`, `self.item_dnn`) are not modified, so
    the graph can be built more than once from the same configuration.

    Returns:
      dict: `self._prediction_dict`.
    """
    capsule_layer = CapsuleLayer(self._model_config.capsule_config,
                                 self._is_training)

    # Optional sequence feature whose (1-dim) values are turned into
    # attention weights over the history sequence; matched by name substring.
    time_id_fea = None
    if self._model_config.time_id_fea:
      time_id_feas = [
          x[0]
          for x in self._hist_seq_features
          if self._model_config.time_id_fea in x[0].name
      ]
      logging.info('time_id_fea is set(%s), find num: %d' %
                   (self._model_config.time_id_fea, len(time_id_feas)))
      if time_id_feas:
        time_id_fea = time_id_feas[0]

    if time_id_fea is not None:
      hist_seq_feas = [
          x[0]
          for x in self._hist_seq_features
          if self._model_config.time_id_fea not in x[0].name
      ]
    else:
      hist_seq_feas = [x[0] for x in self._hist_seq_features]

    # it is assumed that all hist have the same length
    hist_seq_len = self._hist_seq_features[0][1]

    if self._model_config.user_seq_combine == MINDConfig.SUM:
      # Mean pooling over the features (sum divided by the number of
      # features); requires all embeddings to share the same size.
      hist_embed_dims = [x.get_shape()[-1] for x in hist_seq_feas]
      assert all(dim == hist_embed_dims[0] for dim in hist_embed_dims), \
          'all hist seq must have the same embedding shape, but: %s' \
          % str(hist_embed_dims)
      hist_seq_feas = tf.add_n(hist_seq_feas) / len(hist_seq_feas)
    else:
      hist_seq_feas = tf.concat(hist_seq_feas, axis=2)

    if self._model_config.HasField('pre_capsule_dnn') and \
        len(self._model_config.pre_capsule_dnn.hidden_units) > 0:
      pre_dnn_layer = dnn.DNN(self._model_config.pre_capsule_dnn, self._l2_reg,
                              'pre_capsule_dnn', self._is_training)
      hist_seq_feas = pre_dnn_layer(hist_seq_feas)

    if time_id_fea is not None:
      assert time_id_fea.get_shape(
      )[-1] == 1, 'time_id must have only embedding_size of 1'
      # +1e32 for valid steps, -1e32 for padded steps: tf.minimum keeps the
      # valid values and pushes padding to -1e32, so the softmax over the
      # sequence axis gives padded steps ~zero weight.
      time_id_mask = tf.sequence_mask(hist_seq_len, tf.shape(time_id_fea)[1])
      time_id_mask = (tf.cast(time_id_mask, tf.float32) * 2 - 1) * 1e32
      time_id_fea = tf.minimum(time_id_fea, time_id_mask[:, :, None])
      hist_seq_feas = hist_seq_feas * tf.nn.softmax(time_id_fea, axis=1)

    tf.summary.histogram('hist_seq_len', hist_seq_len)

    # batch_size x max_k x high_capsule_dim
    high_capsules, num_high_capsules = capsule_layer(hist_seq_feas,
                                                     hist_seq_len)

    tf.summary.histogram('num_high_capsules', num_high_capsules)
    tf.summary.scalar('high_capsules_norm',
                      tf.reduce_mean(tf.norm(high_capsules, axis=-1)))
    tf.summary.scalar('num_high_capsules',
                      tf.reduce_mean(tf.to_float(num_high_capsules)))

    user_features = tf.layers.batch_normalization(
        self._user_features,
        training=self._is_training,
        trainable=True,
        name='user_fea_bn')
    user_dnn = dnn.DNN(self.user_dnn, self._l2_reg, 'user_dnn',
                       self._is_training)
    user_features = user_dnn(user_features)

    # NOTE(review): this logs the norm of the raw input user features (before
    # batch norm and user_dnn), not of the transformed `user_features`;
    # confirm that is intended.
    tf.summary.scalar('user_features_norm',
                      tf.reduce_mean(tf.norm(self._user_features, axis=-1)))

    # concatenate every capsule with the (shared) user features
    user_features_tile = tf.tile(user_features[:, None, :],
                                 [1, tf.shape(high_capsules)[1], 1])
    user_interests = tf.concat([high_capsules, user_features_tile], axis=2)

    # All but the last concat_dnn layer go through dnn.DNN; the last one is a
    # plain dense layer without activation. The split works on a copy so
    # self.concat_dnn keeps all of its layers across calls.
    concat_dnn_config, last_hidden, num_concat_dnn_layer = (
        self._split_dnn_last_layer(self.concat_dnn, 'concat_dnn'))
    concat_dnn = dnn.DNN(concat_dnn_config, self._l2_reg, 'concat_dnn',
                         self._is_training)
    user_interests = concat_dnn(user_interests)
    # presumably continues dnn.DNN's 'dnn_%d' layer naming; verify in dnn.py
    user_interests = tf.layers.dense(
        inputs=user_interests,
        units=last_hidden,
        kernel_regularizer=self._l2_reg,
        name='concat_dnn/dnn_%d' % (num_concat_dnn_layer - 1))

    # Same split for the item tower, again on a copy of self.item_dnn.
    item_dnn_config, last_item_hidden, num_item_dnn_layer = (
        self._split_dnn_last_layer(self.item_dnn, 'item_dnn'))
    item_dnn = dnn.DNN(item_dnn_config, self._l2_reg, 'item_dnn',
                       self._is_training)
    item_tower_emb = item_dnn(self._item_features)
    item_tower_emb = tf.layers.dense(
        inputs=item_tower_emb,
        units=last_item_hidden,
        kernel_regularizer=self._l2_reg,
        name='item_dnn/dnn_%d' % (num_item_dnn_layer - 1))

    assert self._model_config.simi_func in [
        Similarity.COSINE, Similarity.INNER_PRODUCT
    ]

    if self._model_config.simi_func == Similarity.COSINE:
      item_tower_emb = self.norm(item_tower_emb)
      user_interests = self.norm(user_interests)

    # Label guided attention: score every interest against its positive item.
    # NOTE(review): assumes the first `batch_size` rows of item_tower_emb are
    # the positive items (any extra rows, e.g. sampled negatives, follow).
    batch_size = tf.shape(user_interests)[0]
    pos_item_fea = item_tower_emb[:batch_size]
    simi = tf.einsum('bhe,be->bh', user_interests, pos_item_fea)
    tf.summary.histogram('interest_item_simi/pre_scale',
                         tf.reduce_max(simi, axis=1))
    # Scale the similarities before the softmax; larger simi_pow sharpens
    # the attention towards the most similar interest.
    simi = simi * self._model_config.simi_pow
    tf.summary.histogram('interest_item_simi/scaled',
                         tf.reduce_max(simi, axis=1))
    simi_mask = tf.sequence_mask(num_high_capsules,
                                 self._model_config.capsule_config.max_k)

    # zero out the interests of padded (non-existent) capsules
    user_interests = user_interests * tf.to_float(simi_mask[:, :, None])
    self._prediction_dict['user_interests'] = user_interests

    # padded capsules get -1e32 so they receive ~zero softmax weight
    max_thresh = (tf.cast(simi_mask, tf.float32) * 2 - 1) * 1e32
    simi = tf.minimum(simi, max_thresh)
    simi = tf.nn.softmax(simi, axis=1)
    tf.summary.histogram('interest_item_simi/softmax',
                         tf.reduce_max(simi, axis=1))

    if self._model_config.simi_pow >= 100:
      logging.info(
          'simi_pow=%d, will change to argmax, only use the most similar interests for calculate loss.'
          % self._model_config.simi_pow)
      simi_max_id = tf.argmax(simi, axis=1)
      simi = tf.one_hot(simi_max_id, tf.shape(simi)[1], dtype=tf.float32)

    user_tower_emb = tf.einsum('bhe,bh->be', user_interests, simi)

    # calculate similarity between user_tower_emb and item_tower_emb
    user_item_sim = self.sim(user_tower_emb, item_tower_emb)
    if self._model_config.scale_simi:
      sim_w = tf.get_variable(
          'sim_w',
          dtype=tf.float32,
          shape=(1,),
          initializer=tf.ones_initializer())
      sim_b = tf.get_variable(
          'sim_b',
          dtype=tf.float32,
          shape=(1,),
          initializer=tf.zeros_initializer())
      y_pred = user_item_sim * tf.abs(sim_w) + sim_b
    else:
      y_pred = user_item_sim

    if self._is_point_wise:
      y_pred = tf.reshape(y_pred, [-1])

    if self._loss_type == LossType.CLASSIFICATION:
      self._prediction_dict['logits'] = y_pred
      self._prediction_dict['probs'] = tf.nn.sigmoid(y_pred)
    elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
      y_pred = self._mask_in_batch(y_pred)
      self._prediction_dict['logits'] = y_pred
      self._prediction_dict['probs'] = tf.nn.softmax(y_pred)
    else:
      self._prediction_dict['y'] = y_pred

    # 'user_interests' was already stored above (same tensor).
    self._prediction_dict['high_capsules'] = high_capsules
    self._prediction_dict['user_tower_emb'] = user_tower_emb
    self._prediction_dict['item_tower_emb'] = item_tower_emb
    # serving strings: values joined by ',' within an interest, interests
    # joined by '|'
    self._prediction_dict['user_emb'] = tf.reduce_join(
        tf.reduce_join(tf.as_string(user_interests), axis=-1, separator=','),
        axis=-1,
        separator='|')
    self._prediction_dict['user_emb_num'] = num_high_capsules
    self._prediction_dict['item_emb'] = tf.reduce_join(
        tf.as_string(item_tower_emb), axis=-1, separator=',')

    if self._labels is not None:
      # for summary purpose
      batch_simi, batch_capsule_simi = self._build_interest_simi()
      self._prediction_dict['interests_simi'] = batch_simi
    return self._prediction_dict

  @staticmethod
  def _split_dnn_last_layer(dnn_config, name):
    """Split a DNN config into its leading layers and its last layer size.

    Works on a deep copy, so `dnn_config` itself (possibly part of a model
    config shared across graph builds) is left untouched.

    Args:
      dnn_config: DNN config (presumably a protobuf message) with a
        `hidden_units` list.
      name: tower name, only used in the error message.

    Returns:
      tuple (trunk_config, last_hidden, num_layers): a copy of `dnn_config`
      without its last hidden unit, the size of that last layer, and the
      original number of layers.

    Raises:
      ValueError: if `dnn_config.hidden_units` is empty.
    """
    import copy
    trunk_config = copy.deepcopy(dnn_config)
    num_layers = len(trunk_config.hidden_units)
    if num_layers == 0:
      raise ValueError('%s.hidden_units must not be empty' % name)
    last_hidden = trunk_config.hidden_units.pop()
    return trunk_config, last_hidden, num_layers