def call_layer()

in easy_rec/python/layers/backbone.py [0:0]


  def call_layer(self, inputs, config, name, training, **kwargs):
    """Run the layer described by a backbone block config.

    Dispatches on the ``layer`` oneof of ``config``:

    * ``keras_layer``: delegates to ``self.call_keras_layer``.
    * ``lambda``: evaluates ``expression`` to a callable and applies it to
      ``inputs``.
    * ``repeat``: calls the keras layer ``num_repeat`` times (layer names
      ``<name>_<i>``). Each call's inputs are optionally transformed by
      ``input_slice`` and/or ``input_fn`` (both called as ``f(x, i)``).
    * ``recurrent``: calls the keras layer ``num_steps`` times, feeding each
      step's output back in as the next step's input. When
      ``fixed_input_index`` is set, ``inputs`` must be a list/tuple. The
      element at that index is passed unchanged to every step, and the other
      elements are replaced by the step's outputs.

    NOTE: ``expression``, ``input_slice`` and ``input_fn`` are run through
    ``eval``; they must only ever come from trusted model configuration.

    Args:
      inputs: input(s) of the block.
      config: block config (protobuf-style message with a ``layer`` oneof).
      name: base name used for the keras layer(s).
      training: forwarded to ``call_keras_layer``.
      **kwargs: forwarded to ``call_keras_layer``.

    Returns:
      The layer output(s). For ``repeat``: the single output if exactly one
      was produced, else the outputs concatenated along
      ``output_concat_axis`` when that field is set, else the list of
      outputs. For ``recurrent`` with a fixed input: the non-fixed elements
      of the final state (unwrapped when only one remains). ``inputs`` is
      never modified.

    Raises:
      AssertionError: ``recurrent`` with ``fixed_input_index`` but
        ``inputs`` is not a list/tuple.
      NotImplementedError: the ``layer`` oneof is unset or unsupported.
    """
    layer_name = config.WhichOneof('layer')
    if layer_name == 'keras_layer':
      return self.call_keras_layer(inputs, name, training, **kwargs)
    if layer_name == 'lambda':
      # `lambda` is a Python keyword, so the field is read via getattr.
      conf = getattr(config, 'lambda')
      fn = eval(conf.expression)
      return fn(inputs)
    if layer_name == 'repeat':
      conf = config.repeat
      # Both transforms are loop-invariant callables `(x, i) -> x'`; build
      # them once instead of re-evaluating the config text every iteration.
      slice_fn = None
      if conf.HasField('input_slice'):
        slice_fn = eval('lambda x, i: x' + conf.input_slice.strip())
      input_fn = None
      if conf.HasField('input_fn'):
        input_fn = eval(conf.input_fn)
      outputs = []
      for i in range(conf.num_repeat):
        name_i = '%s_%d' % (name, i)
        ly_inputs = inputs
        if slice_fn is not None:
          ly_inputs = slice_fn(ly_inputs, i)
        if input_fn is not None:
          with tf.name_scope(config.name):
            ly_inputs = input_fn(ly_inputs, i)
        outputs.append(
            self.call_keras_layer(ly_inputs, name_i, training, **kwargs))
      if len(outputs) == 1:
        return outputs[0]
      if conf.HasField('output_concat_axis'):
        return tf.concat(outputs, conf.output_concat_axis)
      return outputs
    if layer_name == 'recurrent':
      conf = config.recurrent
      fixed_input_index = -1
      if conf.HasField('fixed_input_index'):
        fixed_input_index = conf.fixed_input_index
      if fixed_input_index >= 0:
        assert type(inputs) in (tuple, list), (
            '%s inputs must be a list' % name)
        # Private copy: tuples cannot be updated in place, and the caller's
        # list must not be modified by the per-step updates below.
        output = list(inputs)
      else:
        output = inputs
      for i in range(conf.num_steps):
        name_i = '%s_%d' % (name, i)
        output_i = self.call_keras_layer(output, name_i, training, **kwargs)
        if fixed_input_index < 0:
          output = output_i
          continue
        # Fill every non-fixed slot with the step outputs in order; a single
        # (non-list/tuple) output is copied into all non-fixed slots. A new
        # list is built so the one handed to this step is left untouched.
        multi = type(output_i) in (tuple, list)
        next_output = list(output)
        j = 0
        for idx in range(len(next_output)):
          if idx == fixed_input_index:
            continue
          next_output[idx] = output_i[j] if multi else output_i
          j += 1
        output = next_output
      if fixed_input_index >= 0:
        del output[fixed_input_index]
        if len(output) == 1:
          return output[0]
      return output

    # '%s' (not '+') so an unset oneof (None) still raises NotImplementedError.
    raise NotImplementedError('Unsupported backbone layer:%s' % layer_name)