def __init__()

in easy_rec/python/layers/backbone.py


  def __init__(self, config, features, input_layer, l2_reg=None):
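    """Builds the package: registers every block as a DAG node and instantiates its layers.

    Args:
      config: package/backbone config proto providing `name`, `blocks` and the
        optional `concat_blocks` / `output_blocks` lists.
      features: parsed input features handed to the input layers.
      input_layer: InputLayer used to resolve feature groups.
      l2_reg: optional L2 regularizer.
    """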
    self._config = config
    self._features = features
    self._input_layer = input_layer
    self._l2_reg = l2_reg
    self._dag = DAG()
    self._name_to_blocks = {}
    self._name_to_layer = {}
    self.reset_input_config(None)
    self._block_outputs = {}
    self._package_input = None
    self._feature_group_inputs = {}
    reuse = None if config.name == 'backbone' else tf.AUTO_REUSE
    input_feature_groups = self._feature_group_inputs

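    # First pass: register each block as a DAG node and create its layer object(s).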
    for block in config.blocks:
      if len(block.inputs) == 0:
        raise ValueError('block takes at least one input: %s' % block.name)
      self._dag.add_node(block.name)
      self._name_to_blocks[block.name] = block
      layer = block.WhichOneof('layer')
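      # Blocks backed by input_layer/raw_input/embedding_layer read exactly one feature group.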
      if layer in {'input_layer', 'raw_input', 'embedding_layer'}:
        if len(block.inputs) != 1:
          raise ValueError('input layer `%s` takes only one input' % block.name)
        one_input = block.inputs[0]
        name = one_input.WhichOneof('name')
        if name != 'feature_group_name':
          raise KeyError(
              '`feature_group_name` should be set for input layer: ' +
              block.name)
        group = one_input.feature_group_name
        if not input_layer.has_group(group):
          raise KeyError('invalid feature group name: ' + group)
        if group in input_feature_groups:
          if layer == 'input_layer':
            logging.warning('input `%s` already exists in another block' % group)
          elif layer == 'raw_input':
            input_fn = input_feature_groups[group]
            self._name_to_layer[block.name] = input_fn
          elif layer == 'embedding_layer':
            inputs, vocab, weights = input_feature_groups[group]
            block.embedding_layer.vocab_size = vocab
            params = Parameter.make_from_pb(block.embedding_layer)
            input_fn = EmbeddingLayer(params, block.name)
            self._name_to_layer[block.name] = input_fn
        else:
          if layer == 'input_layer':
            input_fn = EnhancedInputLayer(self._input_layer, self._features,
                                          group, reuse)
            input_feature_groups[group] = input_fn
          elif layer == 'raw_input':
            input_fn = self._input_layer.get_raw_features(self._features, group)
            input_feature_groups[group] = input_fn
          else:  # embedding_layer
            inputs, vocab, weights = self._input_layer.get_bucketized_features(
                self._features, group)
            block.embedding_layer.vocab_size = vocab
            params = Parameter.make_from_pb(block.embedding_layer)
            input_fn = EmbeddingLayer(params, block.name)
            input_feature_groups[group] = (inputs, vocab, weights)
            logging.info('add an embedding layer %s with vocab size %d',
                         block.name, vocab)
          self._name_to_layer[block.name] = input_fn
      else:
        self.define_layers(layer, block, block.name, reuse)

      # sequential layers
      for i, layer_cnf in enumerate(block.layers):
        layer = layer_cnf.WhichOneof('layer')
        name_i = '%s_l%d' % (block.name, i)
        self.define_layers(layer, layer_cnf, name_i, reuse)

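    # Blocks that only wrap a feature group do not count; at least one computation block is required.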
    num_groups = len(input_feature_groups)
    num_blocks = len(self._name_to_blocks) - num_groups
    assert num_blocks > 0, 'there must be at least one block in backbone'

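    # Second pass: add DAG edges from each block's declared inputs (blocks, packages or feature groups).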
    num_pkg_input = 0
    for block in config.blocks:
      layer = block.WhichOneof('layer')
      if layer in {'input_layer', 'raw_input', 'embedding_layer'}:
        continue
      name = block.name
      if name in input_feature_groups:
        raise KeyError('block name can not be the same as a feature group: ' + name)
      for input_node in block.inputs:
        input_type = input_node.WhichOneof('name')
        input_name = getattr(input_node, input_type)
        if input_type == 'use_package_input':
          assert input_name, 'use_package_input can not be set to false'
          num_pkg_input += 1
          continue
        if input_type == 'package_name':
          num_pkg_input += 1
          self._dag.add_node_if_not_exists(input_name)
          self._dag.add_edge(input_name, name)
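          # An explicit package_input becomes an upstream dependency of the referenced package.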
          if input_node.HasField('package_input'):
            pkg_input_name = input_node.package_input
            self._dag.add_node_if_not_exists(pkg_input_name)
            self._dag.add_edge(pkg_input_name, input_name)
          continue
        iname = input_name
        if iname in self._name_to_blocks:
          assert iname != name, 'input name can not be the same as the block name: ' + iname
          self._dag.add_edge(iname, name)
        else:
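          # The input is not a declared block: it may be a feature group or a block of the main backbone.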
          is_fea_group = input_type == 'feature_group_name'
          if is_fea_group and input_layer.has_group(iname):
            logging.info('adding an input_layer block: ' + iname)
            new_block = backbone_pb2.Block()
            new_block.name = iname
            input_cfg = backbone_pb2.Input()
            input_cfg.feature_group_name = iname
            new_block.inputs.append(input_cfg)
            new_block.input_layer.CopyFrom(backbone_pb2.InputLayer())
            self._name_to_blocks[iname] = new_block
            self._dag.add_node(iname)
            self._dag.add_edge(iname, name)
            if iname in input_feature_groups:
              fn = input_feature_groups[iname]
            else:
              fn = EnhancedInputLayer(self._input_layer, self._features, iname)
              input_feature_groups[iname] = fn
            self._name_to_layer[iname] = fn
          elif Package.has_backbone_block(iname):
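            # The input is a block of the top-level backbone: record a cross-package dependency.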
            backbone = Package.__packages['backbone']
            backbone._dag.add_node_if_not_exists(self._config.name)
            backbone._dag.add_edge(iname, self._config.name)
            num_pkg_input += 1
          else:
            raise KeyError(
                'invalid input name `%s`, must be the name of either a feature group or another block'
                % iname)
    num_groups = len(input_feature_groups)
    assert num_pkg_input > 0 or num_groups > 0, 'there must be at least one input layer/feature group'

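    # Without explicit concat_blocks/output_blocks, fall back to concatenating all DAG leaf blocks.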
    if len(config.concat_blocks) == 0 and len(config.output_blocks) == 0:
      leaf = self._dag.all_leaves()
      logging.warning(
          '%s has no `concat_blocks` or `output_blocks`, try to concat all leaf blocks: %s'
          % (config.name, ','.join(leaf)))
      self._config.concat_blocks.extend(leaf)

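    # Register this package by name so other packages and blocks can reference it.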
    Package.__packages[self._config.name] = self
    logging.info('%s layers: %s' %
                 (config.name, ','.join(self._name_to_layer.keys())))