in easy_rec/python/layers/backbone.py [0:0]
def call_layer(self, inputs, config, name, training, **kwargs):
    """Run one backbone block according to the layer kind selected in `config`.

    The kind is read from the ``layer`` oneof of ``config``:

    * ``keras_layer`` -- delegates to ``self.call_keras_layer``.
    * ``lambda`` -- evaluates ``config.lambda.expression`` into a callable and
      applies it to ``inputs``.
    * ``repeat`` -- calls ``self.call_keras_layer`` ``num_repeat`` times,
      optionally slicing (``input_slice``) and/or transforming (``input_fn``)
      the inputs for each call, and collects the results.
    * ``recurrent`` -- runs ``self.call_keras_layer`` for ``num_steps`` steps,
      feeding each step's output into the next step. When
      ``fixed_input_index`` is set, that slot of the input list is kept
      constant and the remaining slots are replaced by the step output.

    Args:
      inputs: block input(s). For ``recurrent`` with ``fixed_input_index``
        this must be a list or tuple; it is never modified in place.
      config: block config exposing a ``layer`` oneof (presumably a protobuf
        message -- it is queried with ``WhichOneof`` / ``HasField``).
      name: base name; repeated/recurrent sub-calls use ``'<name>_<i>'``.
      training: forwarded to ``self.call_keras_layer``.
      **kwargs: forwarded to ``self.call_keras_layer``.

    Returns:
      The layer output.
      * ``repeat`` returns the single output when exactly one is produced.
        Otherwise it returns ``tf.concat(outputs, output_concat_axis)`` when
        that field is set, or the list of outputs when it is not.
      * ``recurrent`` with a fixed input returns the non-fixed slots as a
        list, unwrapped when only one remains.
      * ``recurrent`` without a fixed input returns the last step's output.

    Raises:
      NotImplementedError: if the ``layer`` oneof is unset or unsupported.
    """
    layer_name = config.WhichOneof('layer')
    if layer_name == 'keras_layer':
        return self.call_keras_layer(inputs, name, training, **kwargs)

    if layer_name == 'lambda':
        conf = getattr(config, 'lambda')  # `lambda` is a Python keyword
        # NOTE(security): the expression comes from the model config and is
        # executed with eval(); only load configs from trusted sources.
        fn = eval(conf.expression)
        return fn(inputs)

    if layer_name == 'repeat':
        conf = config.repeat
        # The slicing expression does not depend on the loop index, so the
        # lambda is built once rather than re-eval'd on every iteration.
        slice_fn = None
        if conf.HasField('input_slice'):
            slice_fn = eval('lambda x, i: x' + conf.input_slice.strip())
        outputs = []
        for i in range(conf.num_repeat):
            name_i = '%s_%d' % (name, i)
            ly_inputs = inputs
            if slice_fn is not None:
                ly_inputs = slice_fn(ly_inputs, i)
            if conf.HasField('input_fn'):
                with tf.name_scope(config.name):
                    fn = eval(conf.input_fn)
                    ly_inputs = fn(ly_inputs, i)
            outputs.append(
                self.call_keras_layer(ly_inputs, name_i, training, **kwargs))
        if len(outputs) == 1:
            return outputs[0]
        if conf.HasField('output_concat_axis'):
            return tf.concat(outputs, conf.output_concat_axis)
        return outputs

    if layer_name == 'recurrent':
        conf = config.recurrent
        fixed_input_index = -1
        if conf.HasField('fixed_input_index'):
            fixed_input_index = conf.fixed_input_index

        if fixed_input_index < 0:
            # Plain recurrence: each step consumes the previous step's output.
            output = inputs
            for i in range(conf.num_steps):
                name_i = '%s_%d' % (name, i)
                output = self.call_keras_layer(output, name_i, training, **kwargs)
            return output

        assert isinstance(inputs, (tuple, list)), (
            '%s inputs must be a list' % name)
        assert fixed_input_index < len(inputs), (
            '%s fixed_input_index %d out of range for %d inputs' %
            (name, fixed_input_index, len(inputs)))
        # Work on our own list so the caller's list is never mutated and
        # tuple inputs (which are immutable) are supported.
        output = list(inputs)
        # Slots that are refreshed from each step's output (all but the fixed one).
        updated_slots = [k for k in range(len(output)) if k != fixed_input_index]
        for i in range(conf.num_steps):
            name_i = '%s_%d' % (name, i)
            output_i = self.call_keras_layer(output, name_i, training, **kwargs)
            is_multi = isinstance(output_i, (tuple, list))
            next_output = list(output)
            for j, idx in enumerate(updated_slots):
                # A multi-output step fills the slots in order; a single
                # output is broadcast to every non-fixed slot.
                next_output[idx] = output_i[j] if is_multi else output_i
            output = next_output
        del output[fixed_input_index]
        if len(output) == 1:
            return output[0]
        return output

    # `%s` rather than `+`: WhichOneof returns None when the oneof is unset.
    raise NotImplementedError('Unsupported backbone layer:%s' % layer_name)