def embedding_parallel_lookup()

in easy_rec/python/compat/feature_column/feature_column.py [0:0]
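
Shards an embedding lookup across Horovod workers: ids from all features are
flattened into one id tensor with segment information, deduplicated, routed
via alltoall to the worker that owns each id, and the returned rows are
segment-summed into one embedding per (feature, sample) pair.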

# imports needed by the snippet below (the horovod and SparseOperationKit
# paths are assumptions, not copied from the file):
import horovod.tensorflow as hvd
from sparse_operation_kit import dynamic_variable  # assumed sok import path
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops

def embedding_parallel_lookup(embedding,
                              lookup_indices,
                              output_ids,
                              is_training,
                              output_tensors=None,
                              batch_size=None):
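  """Embedding lookup sharded over Horovod workers.

  Ids from all features are flattened into one tensor, deduplicated, routed
  via alltoall to the worker that owns each id, and the returned rows are
  summed per (feature, sample) segment.

  Args:
    embedding: the embedding table, either a sok DynamicVariable or this
      worker's dense shard.
    lookup_indices: a dict with key 'sparse_fea' or keys
      'ragged_ids'/'ragged_lens', or a list with one SparseTensor or
      RaggedTensor per feature.
    output_ids: keys naming each looked-up feature; N = len(output_ids).
    is_training: if False, DynamicVariable reads are lookup-only.
    output_tensors: optional dict that receives the per-feature
      [batch_size, embed_dim] outputs under the keys in output_ids.
    batch_size: number of samples per feature, or None if unknown when the
      graph is built.

  Returns:
    a tensor of shape [batch_size, N * embed_dim].
  """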
  N = len(output_ids)
  if batch_size is None:
    num_segments = None
  else:
    num_segments = N * batch_size
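  # num_segments: one output segment per (feature, sample) pair when the
  # batch size is known up front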
  # first normalize the inputs: concat ids from all features and deduplicate
  if isinstance(lookup_indices, dict) and 'sparse_fea' in lookup_indices:
    # 'sparse_fea' packs all feature ids into one flat tensor together with
    # the number of ids in each (feature, sample) segment
    all_ids, segment_lens = lookup_indices['sparse_fea']
    all_uniq_ids, uniq_idx = array_ops.unique(all_ids)
    cumsum_lens = math_ops.cumsum(segment_lens)
    segment_ids = array_ops.searchsorted(
        cumsum_lens, math_ops.range(cumsum_lens[-1]), side='right')
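    # e.g. segment_lens = [2, 1, 3] -> cumsum_lens = [2, 3, 6], so positions
    # 0..5 get segment_ids [0, 0, 1, 2, 2, 2]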
  elif isinstance(lookup_indices, dict) and (
      'ragged_ids' in lookup_indices and 'ragged_lens' in lookup_indices):
    all_ids = lookup_indices['ragged_ids']
    segment_lens = lookup_indices['ragged_lens']
    all_uniq_ids, uniq_idx = array_ops.unique(all_ids)
    cumsum_lens = math_ops.cumsum(segment_lens)
    segment_ids = array_ops.searchsorted(
        cumsum_lens, math_ops.range(cumsum_lens[-1]), side='right')
  elif isinstance(lookup_indices[0], sparse_tensor_lib.SparseTensor):
    with ops.device('/cpu:0'):
      all_ids = array_ops.concat([x.values for x in lookup_indices], axis=0)
      segment_ids = array_ops.concat([x.indices[:, 0] for x in lookup_indices],
                                     axis=0)
    all_uniq_ids, uniq_idx = array_ops.unique(all_ids)
  elif 'RaggedTensor' in str(type(lookup_indices[0])):
    with ops.device('/cpu:0'):
      all_ids = array_ops.concat([x.values for x in lookup_indices], axis=0)
      segment_lens = array_ops.concat([x.row_lengths() for x in lookup_indices],
                                      axis=0)
    all_uniq_ids, uniq_idx = array_ops.unique(all_ids)
    cumsum_lens = math_ops.cumsum(segment_lens)
    segment_ids = array_ops.searchsorted(
        cumsum_lens, math_ops.range(cumsum_lens[-1]), side='right')
  else:
    raise ValueError('invalid indices type: %s' %
                     str(type(lookup_indices[0])))
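
  # every branch above yields: all_uniq_ids (deduplicated ids), uniq_idx
  # (position of each original id inside all_uniq_ids) and segment_ids
  # (the (feature, sample) segment each original id belongs to)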

  num_parts = hvd.size()
  if num_parts > 1:
    # partition the unique ids by owning worker: id i belongs to worker
    # i % num_parts
    p_assignments = math_ops.cast(all_uniq_ids % num_parts, dtypes.int32)
    gather_ids = data_flow_ops.dynamic_partition(all_uniq_ids, p_assignments,
                                                 num_parts)
    original_ids = math_ops.range(array_ops.size(all_uniq_ids))
    original_part_ids = data_flow_ops.dynamic_partition(original_ids,
                                                        p_assignments,
                                                        num_parts)
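    # original_part_ids records where each unique id sat before partitioning,
    # so the looked-up embeddings can be stitched back into that order below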
    # first alltoall: send each worker the ids it owns
    split_sizes = array_ops.concat([array_ops.shape(x) for x in gather_ids],
                                   axis=0)
    send_ids = array_ops.concat(gather_ids, axis=0)
    recv_ids, recv_lens = hvd.alltoall(send_ids, split_sizes)
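    # recv_ids now holds the ids owned by this worker as requested by every
    # peer; recv_lens records how many arrived from each one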

    # read embedding from dynamic variable
    if isinstance(embedding, dynamic_variable.DynamicVariable):
      send_embed = embedding.sparse_read(
          recv_ids, lookup_only=(not is_training))
    else:
      # dense tables are sharded round-robin over workers: worker 0 holds
      # global rows 0, P, 2P, ..., worker 1 holds 1, P+1, 2P+1, ..., so a
      # global id maps to local row id // num_parts
      recv_ids = math_ops.cast(recv_ids / num_parts, dtypes.int64)
      send_embed = array_ops.gather(embedding, recv_ids)

    # second alltoall: return the gathered rows to the workers that
    # requested them
    recv_embeddings, _ = hvd.alltoall(send_embed, recv_lens)
    recv_embeddings = array_ops.split(
        recv_embeddings, num_or_size_splits=split_sizes)
    recv_embeddings = data_flow_ops.parallel_dynamic_stitch(
        original_part_ids, recv_embeddings, name='parallel_dynamic_stitch')
  else:
    if isinstance(embedding, dynamic_variable.DynamicVariable):
      recv_embeddings = embedding.sparse_read(
          all_uniq_ids, lookup_only=(not is_training))
    else:
      recv_embeddings = array_ops.gather(embedding, all_uniq_ids)

  # sum the rows of each (feature, sample) segment into a single embedding
  embeddings = math_ops.sparse_segment_sum(
      recv_embeddings,
      uniq_idx,
      segment_ids,
      num_segments=num_segments,
      name='sparse_segment_sum')

  embed_dim = embedding.get_shape()[-1]
  output_tensor = array_ops.reshape(embeddings, [N, -1, embed_dim])

  if output_tensors is not None:
    outputs = array_ops.split(output_tensor, num_or_size_splits=N, axis=0)
    for output, output_id in zip(outputs, output_ids):
      output_tensors[output_id] = array_ops.squeeze(output, axis=0)

  if batch_size is None:
    batch_size = -1
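  # [N, B, D] -> [B, N, D] -> [B, N * D]: one row per sample containing the
  # embeddings of all features concatenated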
  return array_ops.reshape(
      array_ops.transpose(output_tensor, perm=[1, 0, 2]),
      [batch_size, N * embed_dim])
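
A minimal single-worker driver sketch (not from easy_rec; the table size,
feature names and RaggedTensor inputs are made up for illustration):

import horovod.tensorflow as hvd
import tensorflow as tf

hvd.init()  # one process, so hvd.size() == 1 and the alltoall path is skipped

# toy table: 100 rows of dimension 4 (with one worker this is the full table)
embedding = tf.Variable(tf.random.normal([100, 4]))

# two features over a batch of 2 samples, one RaggedTensor of ids per feature
ids_a = tf.RaggedTensor.from_row_lengths(
    values=tf.constant([3, 7, 7], dtype=tf.int64), row_lengths=[2, 1])
ids_b = tf.RaggedTensor.from_row_lengths(
    values=tf.constant([5, 9], dtype=tf.int64), row_lengths=[1, 1])

output_tensors = {}
fused = embedding_parallel_lookup(
    embedding, [ids_a, ids_b],
    output_ids=['fea_a', 'fea_b'],
    is_training=True,
    output_tensors=output_tensors,
    batch_size=2)
# fused: [2, 8]; output_tensors['fea_a'] and ['fea_b']: [2, 4] each, e.g.
# sample 0 of fea_a is embedding[3] + embedding[7]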