# easy_rec/python/layers/backbone.py
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import six
import tensorflow as tf
from google.protobuf import struct_pb2
from easy_rec.python.layers.common_layers import EnhancedInputLayer
from easy_rec.python.layers.keras import MLP
from easy_rec.python.layers.keras import EmbeddingLayer
from easy_rec.python.layers.utils import Parameter
from easy_rec.python.protos import backbone_pb2
from easy_rec.python.utils.dag import DAG
from easy_rec.python.utils.load_class import load_keras_layer
from easy_rec.python.utils.tf_utils import add_elements_to_collection
if tf.__version__ >= '2.0':
tf = tf.compat.v1
class Package(object):
"""A sub DAG of tf ops for reuse."""
__packages = {}
@staticmethod
def has_backbone_block(name):
if 'backbone' not in Package.__packages:
return False
backbone = Package.__packages['backbone']
return backbone.has_block(name)
@staticmethod
def backbone_block_outputs(name):
if 'backbone' not in Package.__packages:
return None
backbone = Package.__packages['backbone']
return backbone.block_outputs(name)
def __init__(self, config, features, input_layer, l2_reg=None):
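    """Build the sub DAG of blocks described by `config`.

    Args:
      config: a `backbone_pb2.BlockPackage` proto that defines the blocks
        and how they are wired together.
      features: the input feature tensors of the model.
      input_layer: the shared input layer used to build feature groups.
      l2_reg: optional l2 regularizer passed down to the layers.
    """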
self._config = config
self._features = features
self._input_layer = input_layer
self._l2_reg = l2_reg
self._dag = DAG()
self._name_to_blocks = {}
self._name_to_layer = {}
self.reset_input_config(None)
self._block_outputs = {}
self._package_input = None
self._feature_group_inputs = {}
reuse = None if config.name == 'backbone' else tf.AUTO_REUSE
input_feature_groups = self._feature_group_inputs
for block in config.blocks:
if len(block.inputs) == 0:
raise ValueError('block takes at least one input: %s' % block.name)
self._dag.add_node(block.name)
self._name_to_blocks[block.name] = block
layer = block.WhichOneof('layer')
if layer in {'input_layer', 'raw_input', 'embedding_layer'}:
if len(block.inputs) != 1:
raise ValueError('input layer `%s` takes only one input' % block.name)
one_input = block.inputs[0]
name = one_input.WhichOneof('name')
if name != 'feature_group_name':
raise KeyError(
'`feature_group_name` should be set for input layer: ' +
block.name)
group = one_input.feature_group_name
if not input_layer.has_group(group):
raise KeyError('invalid feature group name: ' + group)
if group in input_feature_groups:
          if layer == 'input_layer':
            logging.warning('input `%s` already exists in another block' % group)
elif layer == 'raw_input':
input_fn = input_feature_groups[group]
self._name_to_layer[block.name] = input_fn
elif layer == 'embedding_layer':
inputs, vocab, weights = input_feature_groups[group]
block.embedding_layer.vocab_size = vocab
params = Parameter.make_from_pb(block.embedding_layer)
input_fn = EmbeddingLayer(params, block.name)
self._name_to_layer[block.name] = input_fn
else:
if layer == 'input_layer':
input_fn = EnhancedInputLayer(self._input_layer, self._features,
group, reuse)
input_feature_groups[group] = input_fn
elif layer == 'raw_input':
input_fn = self._input_layer.get_raw_features(self._features, group)
input_feature_groups[group] = input_fn
else: # embedding_layer
inputs, vocab, weights = self._input_layer.get_bucketized_features(
self._features, group)
block.embedding_layer.vocab_size = vocab
params = Parameter.make_from_pb(block.embedding_layer)
input_fn = EmbeddingLayer(params, block.name)
input_feature_groups[group] = (inputs, vocab, weights)
logging.info('add an embedding layer %s with vocab size %d',
block.name, vocab)
self._name_to_layer[block.name] = input_fn
else:
self.define_layers(layer, block, block.name, reuse)
# sequential layers
for i, layer_cnf in enumerate(block.layers):
layer = layer_cnf.WhichOneof('layer')
name_i = '%s_l%d' % (block.name, i)
self.define_layers(layer, layer_cnf, name_i, reuse)
num_groups = len(input_feature_groups)
num_blocks = len(self._name_to_blocks) - num_groups
assert num_blocks > 0, 'there must be at least one block in backbone'
num_pkg_input = 0
for block in config.blocks:
layer = block.WhichOneof('layer')
if layer in {'input_layer', 'raw_input', 'embedding_layer'}:
continue
name = block.name
if name in input_feature_groups:
        raise KeyError('block name can not be a feature group name: ' + name)
for input_node in block.inputs:
input_type = input_node.WhichOneof('name')
input_name = getattr(input_node, input_type)
if input_type == 'use_package_input':
          assert input_name, 'use_package_input must not be set to false'
num_pkg_input += 1
continue
if input_type == 'package_name':
num_pkg_input += 1
self._dag.add_node_if_not_exists(input_name)
self._dag.add_edge(input_name, name)
if input_node.HasField('package_input'):
pkg_input_name = input_node.package_input
self._dag.add_node_if_not_exists(pkg_input_name)
self._dag.add_edge(pkg_input_name, input_name)
continue
iname = input_name
if iname in self._name_to_blocks:
          assert iname != name, 'input name must not equal block name: ' + iname
self._dag.add_edge(iname, name)
else:
is_fea_group = input_type == 'feature_group_name'
if is_fea_group and input_layer.has_group(iname):
logging.info('adding an input_layer block: ' + iname)
new_block = backbone_pb2.Block()
new_block.name = iname
input_cfg = backbone_pb2.Input()
input_cfg.feature_group_name = iname
new_block.inputs.append(input_cfg)
new_block.input_layer.CopyFrom(backbone_pb2.InputLayer())
self._name_to_blocks[iname] = new_block
self._dag.add_node(iname)
self._dag.add_edge(iname, name)
if iname in input_feature_groups:
fn = input_feature_groups[iname]
else:
fn = EnhancedInputLayer(self._input_layer, self._features, iname)
input_feature_groups[iname] = fn
self._name_to_layer[iname] = fn
elif Package.has_backbone_block(iname):
backbone = Package.__packages['backbone']
backbone._dag.add_node_if_not_exists(self._config.name)
backbone._dag.add_edge(iname, self._config.name)
num_pkg_input += 1
else:
raise KeyError(
                'invalid input name `%s`, must be the name of either a feature group or another block'
% iname)
num_groups = len(input_feature_groups)
assert num_pkg_input > 0 or num_groups > 0, 'there must be at least one input layer/feature group'
if len(config.concat_blocks) == 0 and len(config.output_blocks) == 0:
leaf = self._dag.all_leaves()
logging.warning(
'%s has no `concat_blocks` or `output_blocks`, try to concat all leaf blocks: %s'
% (config.name, ','.join(leaf)))
self._config.concat_blocks.extend(leaf)
Package.__packages[self._config.name] = self
logging.info('%s layers: %s' %
(config.name, ','.join(self._name_to_layer.keys())))
def define_layers(self, layer, layer_cnf, name, reuse):
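    """Create the keras layer object(s) configured for a block.

    `recurrent` and `repeat` blocks create one layer instance per
    step/repetition, registered under the name `<name>_<i>`.
    """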
if layer == 'keras_layer':
layer_obj = self.load_keras_layer(layer_cnf.keras_layer, name, reuse)
self._name_to_layer[name] = layer_obj
elif layer == 'recurrent':
keras_layer = layer_cnf.recurrent.keras_layer
for i in range(layer_cnf.recurrent.num_steps):
name_i = '%s_%d' % (name, i)
layer_obj = self.load_keras_layer(keras_layer, name_i, reuse)
self._name_to_layer[name_i] = layer_obj
elif layer == 'repeat':
keras_layer = layer_cnf.repeat.keras_layer
for i in range(layer_cnf.repeat.num_repeat):
name_i = '%s_%d' % (name, i)
layer_obj = self.load_keras_layer(keras_layer, name_i, reuse)
self._name_to_layer[name_i] = layer_obj
def reset_input_config(self, config):
self.input_config = config
def set_package_input(self, pkg_input):
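    """Set the input passed to this package when it runs as a sub package."""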
self._package_input = pkg_input
def has_block(self, name):
return name in self._name_to_blocks
def block_outputs(self, name):
return self._block_outputs.get(name, None)
def block_input(self, config, block_outputs, training=None, **kwargs):
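    """Collect and merge the inputs declared by a block.

    Inputs may come from other blocks, from other packages or from the
    package input; optional `input_slice` / `input_fn` expressions are
    applied before the inputs are merged into a tensor or a list.
    """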
inputs = []
for input_node in config.inputs:
input_type = input_node.WhichOneof('name')
input_name = getattr(input_node, input_type)
if input_type == 'use_package_input':
input_feature = self._package_input
input_name = 'package_input'
elif input_type == 'package_name':
if input_name not in Package.__packages:
          raise KeyError('package name `%s` does not exist' % input_name)
package = Package.__packages[input_name]
if input_node.HasField('reset_input'):
package.reset_input_config(input_node.reset_input)
if input_node.HasField('package_input'):
pkg_input_name = input_node.package_input
if pkg_input_name in block_outputs:
pkg_input = block_outputs[pkg_input_name]
else:
if pkg_input_name not in Package.__packages:
              raise KeyError('package name `%s` does not exist' %
                             pkg_input_name)
inner_package = Package.__packages[pkg_input_name]
pkg_input = inner_package(training)
if input_node.HasField('package_input_fn'):
fn = eval(input_node.package_input_fn)
pkg_input = fn(pkg_input)
package.set_package_input(pkg_input)
input_feature = package(training, **kwargs)
elif input_name in block_outputs:
input_feature = block_outputs[input_name]
else:
input_feature = Package.backbone_block_outputs(input_name)
if input_feature is None:
          raise KeyError('input name `%s` does not exist' % input_name)
if input_node.ignore_input:
continue
if input_node.HasField('input_slice'):
fn = eval('lambda x: x' + input_node.input_slice.strip())
input_feature = fn(input_feature)
if input_node.HasField('input_fn'):
with tf.name_scope(config.name):
fn = eval(input_node.input_fn)
input_feature = fn(input_feature)
inputs.append(input_feature)
if config.merge_inputs_into_list:
output = inputs
else:
try:
output = merge_inputs(inputs, config.input_concat_axis, config.name)
except ValueError as e:
msg = getattr(e, 'message', str(e))
logging.error('merge inputs of block %s failed: %s', config.name, msg)
raise e
if config.HasField('extra_input_fn'):
fn = eval(config.extra_input_fn)
output = fn(output)
return output
def __call__(self, is_training, **kwargs):
with tf.name_scope(self._config.name):
return self.call(is_training, **kwargs)
def call(self, is_training, **kwargs):
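    """Run the blocks in topological order and return the package output.

    If `output_blocks` is configured their outputs are returned as a list;
    otherwise the outputs of `concat_blocks` are merged into one tensor.
    """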
block_outputs = {}
self._block_outputs = block_outputs # reset
blocks = self._dag.topological_sort()
logging.info(self._config.name + ' topological order: ' + ','.join(blocks))
for block in blocks:
if block not in self._name_to_blocks:
assert block in Package.__packages, 'invalid block: ' + block
continue
config = self._name_to_blocks[block]
if config.layers: # sequential layers
logging.info('call sequential %d layers' % len(config.layers))
output = self.block_input(config, block_outputs, is_training, **kwargs)
for i, layer in enumerate(config.layers):
name_i = '%s_l%d' % (block, i)
output = self.call_layer(output, layer, name_i, is_training, **kwargs)
block_outputs[block] = output
continue
      # a single layer from the `oneof` field
layer = config.WhichOneof('layer')
if layer is None: # identity layer
output = self.block_input(config, block_outputs, is_training, **kwargs)
block_outputs[block] = output
elif layer == 'raw_input':
block_outputs[block] = self._name_to_layer[block]
elif layer == 'input_layer':
input_fn = self._name_to_layer[block]
input_config = config.input_layer
if self.input_config is not None:
input_config = self.input_config
input_fn.reset(input_config, is_training)
block_outputs[block] = input_fn(input_config, is_training)
elif layer == 'embedding_layer':
input_fn = self._name_to_layer[block]
feature_group = config.inputs[0].feature_group_name
inputs, _, weights = self._feature_group_inputs[feature_group]
block_outputs[block] = input_fn([inputs, weights], is_training)
else:
with tf.name_scope(block + '_input'):
inputs = self.block_input(config, block_outputs, is_training,
**kwargs)
output = self.call_layer(inputs, config, block, is_training, **kwargs)
block_outputs[block] = output
outputs = []
for output in self._config.output_blocks:
if output in block_outputs:
temp = block_outputs[output]
outputs.append(temp)
else:
        raise ValueError('output block `%s` does not exist in backbone' % output)
if outputs:
return outputs
for output in self._config.concat_blocks:
if output in block_outputs:
temp = block_outputs[output]
outputs.append(temp)
else:
        raise ValueError('concat block `%s` does not exist in backbone' % output)
try:
output = merge_inputs(outputs, msg='backbone')
except ValueError as e:
msg = getattr(e, 'message', str(e))
logging.error("merge backbone's output failed: %s", msg)
raise e
return output
def load_keras_layer(self, layer_conf, name, reuse=None):
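    """Instantiate the keras layer configured for a block.

    Returns:
      A tuple `(layer, customize)`; `customize` is True for customized
      EasyRec layers, which are constructed from a `Parameter` object, and
      False for native keras layers, which receive `st_params` as kwargs.
    """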
layer_cls, customize = load_keras_layer(layer_conf.class_name)
if layer_cls is None:
raise ValueError('Invalid keras layer class name: ' +
layer_conf.class_name)
param_type = layer_conf.WhichOneof('params')
if customize:
if param_type is None or param_type == 'st_params':
params = Parameter(layer_conf.st_params, True, l2_reg=self._l2_reg)
else:
pb_params = getattr(layer_conf, param_type)
params = Parameter(pb_params, False, l2_reg=self._l2_reg)
has_reuse = True
try:
from funcsigs import signature
sig = signature(layer_cls.__init__)
has_reuse = 'reuse' in sig.parameters.keys()
except ImportError:
try:
from sklearn.externals.funcsigs import signature
sig = signature(layer_cls.__init__)
has_reuse = 'reuse' in sig.parameters.keys()
except ImportError:
logging.warning('import funcsigs failed')
if has_reuse:
layer = layer_cls(params, name=name, reuse=reuse)
else:
layer = layer_cls(params, name=name)
return layer, customize
elif param_type is None: # internal keras layer
layer = layer_cls(name=name)
return layer, customize
else:
      assert param_type == 'st_params', 'internal keras layer only supports st_params'
try:
kwargs = convert_to_dict(layer_conf.st_params)
logging.info('call %s layer with params %r' %
(layer_conf.class_name, kwargs))
layer = layer_cls(name=name, **kwargs)
except TypeError as e:
logging.warning(e)
args = map(format_value, layer_conf.st_params.values())
logging.info('try to call %s layer with params %r' %
(layer_conf.class_name, args))
layer = layer_cls(*args, name=name)
return layer, customize
def call_keras_layer(self, inputs, name, training, **kwargs):
"""Call predefined Keras Layer, which can be reused."""
layer, customize = self._name_to_layer[name]
cls = layer.__class__.__name__
if customize:
try:
output = layer(inputs, training=training, **kwargs)
except Exception as e:
msg = getattr(e, 'message', str(e))
logging.error('call keras layer %s (%s) failed: %s' % (name, cls, msg))
raise e
else:
try:
output = layer(inputs, training=training)
if cls == 'BatchNormalization':
add_elements_to_collection(layer.updates, tf.GraphKeys.UPDATE_OPS)
except TypeError:
output = layer(inputs)
return output
def call_layer(self, inputs, config, name, training, **kwargs):
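    """Dispatch one block to its configured layer type.

    Handles `keras_layer`, `lambda`, `repeat` (several copies of a layer,
    optionally concatenated) and `recurrent` (the same layer applied for
    `num_steps` iterations, optionally keeping one input fixed).
    """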
layer_name = config.WhichOneof('layer')
if layer_name == 'keras_layer':
return self.call_keras_layer(inputs, name, training, **kwargs)
if layer_name == 'lambda':
conf = getattr(config, 'lambda')
fn = eval(conf.expression)
return fn(inputs)
if layer_name == 'repeat':
conf = config.repeat
n_loop = conf.num_repeat
outputs = []
for i in range(n_loop):
name_i = '%s_%d' % (name, i)
ly_inputs = inputs
if conf.HasField('input_slice'):
fn = eval('lambda x, i: x' + conf.input_slice.strip())
ly_inputs = fn(ly_inputs, i)
if conf.HasField('input_fn'):
with tf.name_scope(config.name):
fn = eval(conf.input_fn)
ly_inputs = fn(ly_inputs, i)
output = self.call_keras_layer(ly_inputs, name_i, training, **kwargs)
outputs.append(output)
if len(outputs) == 1:
return outputs[0]
if conf.HasField('output_concat_axis'):
return tf.concat(outputs, conf.output_concat_axis)
return outputs
if layer_name == 'recurrent':
conf = config.recurrent
fixed_input_index = -1
if conf.HasField('fixed_input_index'):
fixed_input_index = conf.fixed_input_index
if fixed_input_index >= 0:
        assert type(inputs) in (tuple, list), '%s inputs must be a list' % name
output = inputs
for i in range(conf.num_steps):
name_i = '%s_%d' % (name, i)
output_i = self.call_keras_layer(output, name_i, training, **kwargs)
if fixed_input_index >= 0:
j = 0
for idx in range(len(output)):
if idx == fixed_input_index:
continue
if type(output_i) in (tuple, list):
output[idx] = output_i[j]
else:
output[idx] = output_i
j += 1
else:
output = output_i
if fixed_input_index >= 0:
del output[fixed_input_index]
if len(output) == 1:
return output[0]
return output
return output
raise NotImplementedError('Unsupported backbone layer:' + layer_name)
class Backbone(object):
"""Configurable Backbone Network."""
def __init__(self, config, features, input_layer, l2_reg=None):
self._config = config
self._l2_reg = l2_reg
main_pkg = backbone_pb2.BlockPackage()
main_pkg.name = 'backbone'
main_pkg.blocks.MergeFrom(config.blocks)
if config.concat_blocks:
main_pkg.concat_blocks.extend(config.concat_blocks)
if config.output_blocks:
main_pkg.output_blocks.extend(config.output_blocks)
self._main_pkg = Package(main_pkg, features, input_layer, l2_reg)
for pkg in config.packages:
Package(pkg, features, input_layer, l2_reg)
def __call__(self, is_training, **kwargs):
output = self._main_pkg(is_training, **kwargs)
if self._config.HasField('top_mlp'):
params = Parameter.make_from_pb(self._config.top_mlp)
params.l2_regularizer = self._l2_reg
final_mlp = MLP(params, name='backbone_top_mlp')
if type(output) in (list, tuple):
output = tf.concat(output, axis=-1)
output = final_mlp(output, training=is_training, **kwargs)
return output
@classmethod
def wide_embed_dim(cls, config):
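    """Return the wide_output_dim configured by input_layer blocks.

    All packages and the main blocks must agree on the same value.
    """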
wide_embed_dim = None
for pkg in config.packages:
wide_embed_dim = get_wide_embed_dim(pkg.blocks, wide_embed_dim)
return get_wide_embed_dim(config.blocks, wide_embed_dim)
def get_wide_embed_dim(blocks, wide_embed_dim=None):
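  """Scan `blocks` for input_layer blocks and return their wide_output_dim.

  Raises an assertion error if two blocks configure different values.
  """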
for block in blocks:
layer = block.WhichOneof('layer')
if layer == 'input_layer':
if block.input_layer.HasField('wide_output_dim'):
wide_dim = block.input_layer.wide_output_dim
if wide_embed_dim:
assert wide_embed_dim == wide_dim, 'wide_output_dim must be consistent'
else:
wide_embed_dim = wide_dim
return wide_embed_dim
def merge_inputs(inputs, axis=-1, msg=''):
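  """Merge the outputs of multiple blocks into one value.

  If every input is a list, the lists are flattened into one list; if the
  inputs are mixed, everything is merged into a list; otherwise the tensors
  are concatenated along `axis`.
  """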
if len(inputs) == 0:
    raise ValueError('no inputs to be concatenated: ' + msg)
if len(inputs) == 1:
return inputs[0]
from functools import reduce
if all(map(lambda x: type(x) == list, inputs)):
# merge multiple lists into a list
return reduce(lambda x, y: x + y, inputs)
if any(map(lambda x: type(x) == list, inputs)):
logging.warning('%s: try to merge inputs into list' % msg)
return reduce(lambda x, y: x + y,
[e if type(e) == list else [e] for e in inputs])
if axis != -1:
logging.info('concat inputs %s axis=%d' % (msg, axis))
return tf.concat(inputs, axis=axis)
def format_value(value):
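  """Convert a protobuf Struct value into a plain python value."""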
value_type = type(value)
if value_type == six.text_type:
return str(value)
if value_type == float:
int_v = int(value)
return int_v if int_v == value else value
if value_type == struct_pb2.ListValue:
return map(format_value, value)
if value_type == struct_pb2.Struct:
return convert_to_dict(value)
return value
def convert_to_dict(struct):
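  """Convert a `google.protobuf.Struct` into a dict of plain python values."""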
kwargs = {}
for key, value in struct.items():
kwargs[str(key)] = format_value(value)
return kwargs