in easy_rec/python/compat/feature_column/feature_column.py [0:0]
def embedding_parallel_lookup(embedding,
                              lookup_indices,
                              output_ids,
                              is_training,
                              output_tensors=None,
                              batch_size=None):
  """Embedding lookup with the table sharded across horovod workers.

  Ids are deduplicated locally, routed to the worker that owns each id
  (owner = id % hvd.size()) via alltoall, looked up there, and routed back
  by a second alltoall. The gathered embeddings are then sum-pooled per
  feature segment.

  Args:
    embedding: this worker's embedding table shard; either a
      `dynamic_variable.DynamicVariable` (keyed by raw id) or a dense
      variable storing rows round-robin (worker p holds ids with
      id % num_parts == p at local row id // num_parts).
    lookup_indices: one of
      - dict with key 'sparse_fea': tuple (flat ids, per-segment lengths);
      - dict with keys 'ragged_ids' and 'ragged_lens';
      - list of `SparseTensor`s, one per feature;
      - list of `RaggedTensor`s, one per feature.
    output_ids: list of output slot ids, one per feature; N = len(output_ids).
    is_training: bool; for `DynamicVariable`, False reads with
      lookup_only=True.
    output_tensors: optional mapping; if given, receives each feature's
      pooled [batch, embed_dim] tensor keyed by its entry in output_ids.
    batch_size: optional static batch size; when given, fixes num_segments
      of the segment sum to N * batch_size.

  Returns:
    Tensor of shape [batch_size, N * embed_dim]: pooled embeddings of all
    features concatenated along the last axis.

  Raises:
    ValueError: if lookup_indices is of an unsupported type.
  """
  N = len(output_ids)
  num_segments = None if batch_size is None else N * batch_size

  def _uniq_and_segments(flat_ids, seg_lens):
    # Deduplicate ids and assign each flat id position to its feature
    # segment: position i belongs to the first segment whose cumulative
    # length exceeds i.
    uniq_ids, uniq_pos = array_ops.unique(flat_ids)
    cum_lens = math_ops.cumsum(seg_lens)
    seg_ids = array_ops.searchsorted(
        cum_lens, math_ops.range(cum_lens[-1]), side='right')
    return uniq_ids, uniq_pos, seg_ids

  # Normalize every supported input format into:
  #   all_uniq_ids: deduplicated ids to look up
  #   uniq_idx:     maps each flat id position to its row in all_uniq_ids
  #   segment_ids:  maps each flat id position to its pooling segment
  if isinstance(lookup_indices, dict) and 'sparse_fea' in lookup_indices:
    all_ids, segment_lens = lookup_indices['sparse_fea']
    all_uniq_ids, uniq_idx, segment_ids = _uniq_and_segments(
        all_ids, segment_lens)
  elif isinstance(lookup_indices, dict) and (
      'ragged_ids' in lookup_indices and 'ragged_lens' in lookup_indices):
    all_uniq_ids, uniq_idx, segment_ids = _uniq_and_segments(
        lookup_indices['ragged_ids'], lookup_indices['ragged_lens'])
  elif isinstance(lookup_indices[0], sparse_tensor_lib.SparseTensor):
    with ops.device('/cpu:0'):
      all_ids = array_ops.concat([x.values for x in lookup_indices], axis=0)
      # SparseTensor indices already carry the batch row in column 0, so no
      # cumsum/searchsorted is needed here.
      segment_ids = array_ops.concat([x.indices[:, 0] for x in lookup_indices],
                                     axis=0)
      all_uniq_ids, uniq_idx = array_ops.unique(all_ids)
  elif 'RaggedTensor' in str(type(lookup_indices[0])):
    with ops.device('/cpu:0'):
      all_ids = array_ops.concat([x.values for x in lookup_indices], axis=0)
      segment_lens = array_ops.concat(
          [x.row_lengths() for x in lookup_indices], axis=0)
      all_uniq_ids, uniq_idx, segment_ids = _uniq_and_segments(
          all_ids, segment_lens)
  else:
    # NOTE: was `assert False`, which is stripped under `python -O` and
    # would fall through with undefined variables; raise explicitly instead.
    raise ValueError('invalid indices type: %s' % str(type(lookup_indices[0])))

  num_parts = hvd.size()
  if num_parts > 1:
    # Route each unique id to the worker that owns it: owner = id % num_parts.
    p_assignments = math_ops.cast(all_uniq_ids % num_parts, dtypes.int32)
    gather_ids = data_flow_ops.dynamic_partition(all_uniq_ids, p_assignments,
                                                 num_parts)
    # Remember each id's pre-partition position so the returned embeddings
    # can be stitched back into the original order.
    original_ids = math_ops.range(array_ops.size(all_uniq_ids))
    original_part_ids = data_flow_ops.dynamic_partition(original_ids,
                                                        p_assignments,
                                                        num_parts)
    # First alltoall: exchange the ids each worker must look up.
    split_sizes = array_ops.concat([array_ops.shape(x) for x in gather_ids],
                                   axis=0)
    send_ids = array_ops.concat(gather_ids, axis=0)
    recv_ids, recv_lens = hvd.alltoall(send_ids, split_sizes)
    if isinstance(embedding, dynamic_variable.DynamicVariable):
      # Dynamic variables are keyed by the raw global id.
      send_embed = embedding.sparse_read(
          recv_ids, lookup_only=(not is_training))
    else:
      # Dense shards store ids round-robin:
      #   worker 0: 0 2 4 6 8 10 ...
      #   worker 1: 1 3 5 7 9 11 ...
      # so the local row index is id // num_parts. Use integer floordiv
      # instead of true division + cast: `/` on integer tensors goes through
      # float64 and loses precision for ids above 2**53.
      recv_ids = math_ops.cast(
          math_ops.floordiv(recv_ids, num_parts), dtypes.int64)
      send_embed = array_ops.gather(embedding, recv_ids)
    # Second alltoall: return the looked-up embeddings to their requesters.
    recv_embeddings, _ = hvd.alltoall(send_embed, recv_lens)
    recv_embeddings = array_ops.split(
        recv_embeddings, num_or_size_splits=split_sizes)
    # Restore the original (pre-partition) ordering of the unique ids.
    recv_embeddings = data_flow_ops.parallel_dynamic_stitch(
        original_part_ids, recv_embeddings, name='parallel_dynamic_stitch')
  else:
    # Single worker: the whole table is local, no routing needed.
    if isinstance(embedding, dynamic_variable.DynamicVariable):
      recv_embeddings = embedding.sparse_read(
          all_uniq_ids, lookup_only=(not is_training))
    else:
      recv_embeddings = array_ops.gather(embedding, all_uniq_ids)
  # Sum-pool the per-id embeddings into one vector per feature segment.
  embeddings = math_ops.sparse_segment_sum(
      recv_embeddings,
      uniq_idx,
      segment_ids,
      num_segments=num_segments,
      name='sparse_segment_sum')
  embed_dim = embedding.get_shape()[-1]
  # [N, batch, embed_dim]: one slice per feature.
  output_tensor = array_ops.reshape(embeddings, [N, -1, embed_dim])
  if output_tensors is not None:
    outputs = array_ops.split(output_tensor, num_or_size_splits=N, axis=0)
    for output, output_id in zip(outputs, output_ids):
      output_tensors[output_id] = array_ops.squeeze(output, axis=0)
  if batch_size is None:
    batch_size = -1  # let reshape infer the batch dimension
  # Transpose to [batch, N, embed_dim] then flatten the feature axis.
  return array_ops.reshape(
      array_ops.transpose(output_tensor, perm=[1, 0, 2]),
      [batch_size, N * embed_dim])