def _internal_input_layer()

in easy_rec/python/compat/feature_column/feature_column.py


def _internal_input_layer(features,
                          feature_columns,
                          weight_collections=None,
                          trainable=True,
                          cols_to_vars=None,
                          scope=None,
                          cols_to_output_tensors=None,
                          from_template=False,
                          feature_name_to_output_tensors=None,
                          is_training=True):
  """See input_layer, `scope` is a name or variable scope to use."""
  feature_columns = _normalize_feature_columns(feature_columns)
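  # Only dense columns are accepted (e.g. numeric, embedding_column,
  # indicator_column); the global and model variable collections are always
  # appended to weight_collections.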
  for column in feature_columns:
    if not isinstance(column, _DenseColumn):
      raise ValueError(
          'Items of feature_columns must be a _DenseColumn. '
          'You can wrap a categorical column with an '
          'embedding_column or indicator_column. Given: {}'.format(column))
  weight_collections = list(weight_collections or [])
  if ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections:
    weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
  if ops.GraphKeys.MODEL_VARIABLES not in weight_collections:
    weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)

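  # Standard (non-parallel) path: build each column's dense tensor in its own
  # variable scope, reshape it to (batch_size, num_elements) and concatenate
  # all columns along axis 1.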
  def _get_logits():  # pylint: disable=missing-docstring
    builder = _LazyBuilder(features)
    output_tensors = []

    tmp_cols = feature_columns
    if embedding_utils.sort_col_by_name():
      logging.info('will sort columns[len=%d] by name' % len(tmp_cols))
      tmp_cols = sorted(tmp_cols, key=lambda x: x.name)
    for column in tmp_cols:
      with variable_scope.variable_scope(
          None, default_name=column._var_scope_name):  # pylint: disable=protected-access
        tensor = column._get_dense_tensor(  # pylint: disable=protected-access
            builder,
            weight_collections=weight_collections,
            trainable=trainable)
        num_elements = column._variable_shape.num_elements()  # pylint: disable=protected-access
        batch_size = array_ops.shape(tensor)[0]
        output_tensor = array_ops.reshape(
            tensor, shape=(batch_size, num_elements))
        output_tensors.append(output_tensor)
        if cols_to_vars is not None:
          # Retrieve any variables created (some _DenseColumns don't create
          # variables, in which case an empty list is returned).
          cols_to_vars[column] = ops.get_collection(
              ops.GraphKeys.GLOBAL_VARIABLES,
              scope=variable_scope.get_variable_scope().name)
        if cols_to_output_tensors is not None:
          cols_to_output_tensors[column] = output_tensor
        if feature_name_to_output_tensors is not None:
          feature_name_to_output_tensors[column.raw_name] = output_tensor
    return array_ops.concat(output_tensors, 1)

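  # Embedding-parallel path: embedding tables are partitioned across horovod
  # workers (or stored in SOK DynamicVariables) and served through
  # embedding_parallel_lookup.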
  def _get_logits_embedding_parallel():  # pylint: disable=missing-docstring
    assert hvd is not None, 'horovod is not installed'
    builder = _LazyBuilder(features)

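    # Place embedding tables on CPU when embedding_on_cpu() is set, otherwise
    # on the first GPU.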
    if embedding_utils.embedding_on_cpu():
      embedding_device = '/cpu:0'
    else:
      embedding_device = '/gpu:0'

    def _get_var_type(column):
      if column.ev_params.use_cache:
        return 'hybrid'
      else:
        return None

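    # Bookkeeping: output_tensors keeps one slot per column (in input order);
    # the lookup_* lists collect embedding-lookup work so each table can be
    # queried in batches, while dense_* track columns that need no embedding.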
    output_tensors = []
    ordered_columns = []

    lookup_embeddings = []
    lookup_indices = None
    lookup_combiners = []
    lookup_cols = []
    lookup_output_ids = []
    lookup_wgts = []

    dense_cols = []
    dense_output_ids = []

    shared_weights = {}
    dense_cnt = 0

    batch_sizes = []
    for column in feature_columns:
      ordered_columns.append(column)
      with variable_scope.variable_scope(
          None, default_name=column._var_scope_name):  # pylint: disable=protected-access
        # Features that do not require embedding are passed through directly.
        if 'Embedding' not in str(type(column)):
          dense_cols.append(column)
          dense_output_ids.append(len(output_tensors))
          output_tensors.append(None)
          dense_cnt += 1
          continue

        # Features that require embedding: shard the embedding table rows
        # across workers, ceil(num_buckets / hvd.size()) rows per worker.
        num_buckets = column.categorical_column.num_buckets + hvd.size() - 1
        per_worker_buckets = num_buckets // hvd.size()
        embedding_shape = (per_worker_buckets, column.dimension)
        if 'SharedEmbedding' in str(type(column)):
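          # Shared-embedding columns reuse one weight variable per
          # shared_embedding_collection_name.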
          shared_name = column.shared_embedding_collection_name
          if shared_name in shared_weights:
            embedding_weights = shared_weights[shared_name]
          else:
            with ops.device(embedding_device):
              if column.ev_params is not None:
                assert dynamic_variable is not None, 'sok is not installed'
                embedding_weights = dynamic_variable.DynamicVariable(
                    name='embedding_weights',
                    dimension=column.dimension,
                    initializer='random {"stddev":0.0025}',  # column.initializer,
                    var_type=_get_var_type(column),
                    trainable=column.trainable and trainable,
                    dtype=dtypes.float32,
                    init_capacity=column.ev_params.init_capacity,
                    max_capacity=column.ev_params.max_capacity)
              else:
                embedding_weights = variable_scope.get_variable(
                    name='embedding_weights',
                    shape=embedding_shape,
                    dtype=dtypes.float32,
                    initializer=column.initializer,
                    trainable=column.trainable and trainable,
                    partitioner=None,
                    collections=weight_collections)
            shared_weights[shared_name] = embedding_weights
        else:
          with ops.device(embedding_device):
            if column.ev_params is not None:
              assert dynamic_variable is not None, 'sok is not installed'
              embedding_weights = dynamic_variable.DynamicVariable(
                  name='embedding_weights',
                  dimension=column.dimension,
                  initializer='random {"stddev":0.0025}',  # column.initializer,
                  var_type=_get_var_type(column),
                  trainable=column.trainable and trainable,
                  dtype=dtypes.float32,
                  init_capacity=column.ev_params.init_capacity,
                  max_capacity=column.ev_params.max_capacity)
            else:
              embedding_weights = variable_scope.get_variable(
                  name='embedding_weights',
                  shape=embedding_shape,
                  dtype=dtypes.float32,
                  initializer=column.initializer,
                  trainable=column.trainable and trainable,
                  partitioner=None,
                  collections=weight_collections)
        lookup_embeddings.append(embedding_weights)
        output_id = len(output_tensors)
        output_tensors.append(None)
        lookup_output_ids.append(output_id)
        lookup_cols.append(column)
        lookup_combiners.append(column.combiner)

        # Packed inputs ('sparse_fea' SparseTensor, or 'ragged_ids' with
        # 'ragged_lens') are looked up in a single pass; otherwise the
        # per-column sparse tensors are not gathered into one, which may
        # cause performance issues.
        if 'sparse_fea' in features.keys():
          if lookup_indices is None:
            lookup_indices = {'sparse_fea': features['sparse_fea']}
        elif 'ragged_ids' in features.keys():
          if lookup_indices is None:
            lookup_indices = {
                'ragged_ids': features['ragged_ids'],
                'ragged_lens': features['ragged_lens']
            }
            if 'ragged_wgts' in features:
              lookup_indices['ragged_wgts'] = features['ragged_wgts']
        else:
          if lookup_indices is None:
            lookup_indices = []
          with ops.device('/cpu:0'):
            sparse_tensors = column.categorical_column._get_sparse_tensors(
                builder,
                weight_collections=weight_collections,
                trainable=trainable)
            lookup_indices.append(sparse_tensors.id_tensor)
          if sparse_tensors.weight_tensor is not None:
            lookup_wgts.append(sparse_tensors.weight_tensor)
        if cols_to_vars is not None:
          cols_to_vars[column] = ops.get_collection(
              ops.GraphKeys.GLOBAL_VARIABLES,
              scope=variable_scope.get_variable_scope().name)

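    # Non-embedding (dense) features: slice them out of the packed 'dense_fea'
    # tensor when it is present, otherwise read each one by its raw_name.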
    if dense_cnt > 0:
      if 'dense_fea' in features:
        fea_dim_s = 0
        for dense_output_id, dense_col in zip(dense_output_ids, dense_cols):
          fea_dim_e = fea_dim_s + dense_col.shape[0]
          output_tensors[dense_output_id] = features[
              'dense_fea'][:, fea_dim_s:fea_dim_e]
          fea_dim_s = fea_dim_e
        batch_sizes.append(array_ops.shape(features['dense_fea'])[0])
      else:
        for dense_output_id, dense_col in zip(dense_output_ids, dense_cols):
          output_tensors[dense_output_id] = features[dense_col.raw_name]
        batch_sizes.append(array_ops.shape(output_tensors[dense_output_id])[0])

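    # Record every distinct embedding table in the EmbeddingParallel
    # collection.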
    for tmp_embed_var in set(lookup_embeddings):
      ops.add_to_collection(constant.EmbeddingParallel, tmp_embed_var.name)

    if len(batch_sizes) == 0:
      batch_size = None
    else:
      batch_size = batch_sizes[0]
    # do embedding parallel lookup
    if len(lookup_output_ids) > 0:
      packed_input = ('sparse_fea' in features or 'ragged_ids' in features)
      if packed_input:
        uniq_embed_cnt = len(set(lookup_embeddings))
        assert uniq_embed_cnt == 1, (
            'only one unique embedding table is supported for packed inputs')
        outputs = embedding_parallel_lookup(lookup_embeddings[0],
                                            lookup_indices, lookup_output_ids,
                                            is_training, output_tensors,
                                            batch_size)
      else:
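        # No dense feature provided a batch size: derive it from the last row
        # index of each SparseTensor's indices (max over columns, plus one).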
        if batch_size is None:
          all_indices = []
          for lookup_indice in lookup_indices:
            all_indices.append(lookup_indice.indices[-1:, 0])
          all_indices = array_ops.concat(all_indices, axis=0)
          batch_size = math_ops.reduce_max(all_indices) + 1
        # Group lookups that share an embedding table so that each table is
        # queried only once.
        grouped_inputs = {}
        for embedding, lookup_indice, output_id in zip(lookup_embeddings,
                                                       lookup_indices,
                                                       lookup_output_ids):
          if embedding not in grouped_inputs:
            grouped_inputs[embedding] = {
                'lookup_indice': [lookup_indice],
                'output_id': [output_id]
            }
          else:
            grouped_inputs[embedding]['lookup_indice'].append(lookup_indice)
            grouped_inputs[embedding]['output_id'].append(output_id)

        for embedding in grouped_inputs:
          lookup_indices = grouped_inputs[embedding]['lookup_indice']
          output_ids = grouped_inputs[embedding]['output_id']
          outputs = embedding_parallel_lookup(embedding, lookup_indices,
                                              output_ids, is_training,
                                              output_tensors, batch_size)

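      # Map each column's output tensor back into the caller-provided dicts.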
      for output_tensor, col in zip(output_tensors, feature_columns):
        if feature_name_to_output_tensors is not None:
          feature_name_to_output_tensors[col.raw_name] = output_tensor
        if cols_to_output_tensors is not None:
          cols_to_output_tensors[col] = output_tensor

      if packed_input and dense_cnt == 0:
        return outputs
      else:
        return array_ops.concat(output_tensors, axis=1)
    else:
      for output_tensor, col in zip(output_tensors, feature_columns):
        if feature_name_to_output_tensors is not None:
          feature_name_to_output_tensors[col.raw_name] = output_tensor
        if cols_to_output_tensors is not None:
          cols_to_output_tensors[col] = output_tensor
      return array_ops.concat(output_tensors, axis=1)

  # If we're constructing via `make_template`, it by default adds a variable
  # scope with the name of the layer. In that case, we don't want to add
  # another `variable_scope` as that would break checkpoints.
  if from_template:
    return _get_logits()
  else:
    with variable_scope.variable_scope(
        scope, default_name='input_layer', values=features.values()):
      if embedding_utils.is_embedding_parallel():
        return _get_logits_embedding_parallel()
      else:
        with conditional(embedding_utils.embedding_on_cpu(),
                         ops.device('/cpu:0')):
          return _get_logits()
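
For orientation, below is a minimal usage sketch of the standard (non-parallel)
path. It is an illustration, not part of the source file: `fc` is an assumed
alias for this compat module (easy_rec.python.compat.feature_column.feature_column),
which is expected to expose `numeric_column` like the upstream TF 1.x
feature_column module it derives from, and the snippet assumes a TF1-style
graph context.

# Hypothetical sketch; `fc` is an assumed alias for this compat module.
import tensorflow as tf

features = {
    'price': tf.constant([[1.0], [2.0]]),
    'hour': tf.constant([[3.0], [4.0]]),
}
cols = [fc.numeric_column('price'), fc.numeric_column('hour')]

cols_to_out = {}
net = _internal_input_layer(
    features, cols, cols_to_output_tensors=cols_to_out)
# net has shape (2, 2): each column is reshaped to (batch_size, num_elements)
# and concatenated along axis 1; cols_to_out maps every column to its output.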