def call()

in easy_rec/python/layers/common_layers.py [0:0]


  def call(self, config, training):
    """Apply normalization and dropout to this group's input features.

    Reads ``self.inputs`` as a ``(features, feature_list)`` pair and, depending
    on ``config``, applies:

      * feature dropout (only when ``training``): every entry of
        ``feature_list`` is divided by ``keep_prob`` and multiplied by its own
        value from ``self.bern.sample(num_features)``; ``features`` is then
        rebuilt by concatenating the processed list along the last axis;
      * batch norm or layer norm (batch norm takes precedence when both are
        set) on every entry of ``feature_list`` and, when feature dropout is
        off, separately on the combined ``features`` tensor;
      * standard dropout through ``self.dropout``. Entries of ``feature_list``
        are dropped out only when the list is part of the output; in that case
        (without feature dropout) ``features`` and ``feature_list`` receive
        independently sampled dropout.

    The list stored in ``self.inputs`` is not modified; a new list is returned.

    Args:
      config: config message; the fields read are ``do_layer_norm``,
        ``do_batch_norm``, ``feature_dropout_rate``, ``dropout_rate``,
        ``output_2d_tensor_and_feature_list``, ``only_output_feature_list``
        and ``only_output_3d_tensor``.
      training: whether the graph is being built for training.

    Returns:
      ``feature_list`` if ``only_output_feature_list``; otherwise
      ``tf.stack(feature_list, axis=1)`` if ``only_output_3d_tensor``;
      otherwise ``(features, feature_list)`` if
      ``output_2d_tensor_and_feature_list``; otherwise ``features``.
      A 3-D ``features`` whose leading dimension is statically 1 is squeezed
      along axis 0 before being returned.
    """
    features, feature_list = self.inputs
    # Work on a copy: assigning into the list held by self.inputs would make a
    # repeated call re-normalize / re-mask tensors that were already processed.
    feature_list = list(feature_list)
    num_features = len(feature_list)

    do_ln = config.do_layer_norm
    do_bn = config.do_batch_norm
    do_feature_dropout = training and 0.0 < config.feature_dropout_rate < 1.0
    if do_feature_dropout:
      keep_prob = 1.0 - config.feature_dropout_rate
      # One mask value per feature. NOTE(review): assumes self.bern is a
      # Bernoulli distribution set up elsewhere with probability keep_prob.
      mask = self.bern.sample(num_features)
    elif do_bn:
      features = tf.layers.batch_normalization(
          features, training=training, reuse=self._reuse)
    elif do_ln:
      features = layer_norm(
          features, name=self._group_name + '_features', reuse=self._reuse)

    output_feature_list = (
        config.output_2d_tensor_and_feature_list or
        config.only_output_feature_list or config.only_output_3d_tensor)
    do_dropout = 0.0 < config.dropout_rate < 1.0
    # Dropping out individual features is only needed when the list itself
    # (or a tensor stacked from it) is part of the output.
    per_feature_dropout = do_dropout and output_feature_list
    if do_feature_dropout or do_ln or do_bn or do_dropout:
      for i, fea in enumerate(feature_list):
        if do_bn:
          fea = tf.layers.batch_normalization(
              fea, training=training, reuse=self._reuse)
        elif do_ln:
          ln_name = self._group_name + 'f_%d' % i
          fea = layer_norm(fea, name=ln_name, reuse=self._reuse)
        if per_feature_dropout:
          fea = self.dropout.call(fea, training=training)
        if do_feature_dropout:
          # Rescale by 1 / keep_prob, then gate the whole feature by its mask.
          fea = fea / keep_prob * mask[i]
        feature_list[i] = fea
      if do_feature_dropout:
        # Rebuild the combined tensor from the masked per-feature tensors.
        features = tf.concat(feature_list, axis=-1)

    # `features` already has standard dropout only when it was rebuilt above
    # from per-feature tensors that were each dropped out. Checking just
    # `not do_feature_dropout` would silently disable dropout_rate whenever
    # feature dropout is active but no feature-list output is requested.
    if do_dropout and not (do_feature_dropout and per_feature_dropout):
      features = self.dropout.call(features, training=training)
    # Compare the leading dimension with == instead of int(): int() raises
    # when that dimension is unknown (None), while == just evaluates falsy.
    if features.shape.ndims == 3 and features.shape[0] == 1:
      features = tf.squeeze(features, axis=0)

    if config.only_output_feature_list:
      return feature_list
    if config.only_output_3d_tensor:
      return tf.stack(feature_list, axis=1)
    if config.output_2d_tensor_and_feature_list:
      return features, feature_list
    return features