in tensorflow_recommenders_addons/dynamic_embedding/core/kernels/segment_reduction_ops_impl.h [80:162]
// Asynchronously computes a sparse segment sum on the GPU.
//
// Inputs read from `context`:
//   0: input_data   - tensor of rank >= 1.
//   1: indices      - vector.
//   2: segment_ids  - vector with the same number of elements as `indices`.
//   3: num_segments - only read when `has_num_segments` is set; supplies the
//                     number of output rows directly.
//
// When `has_num_segments` is not set, the number of output rows is the last
// element of `segment_ids` plus one; that element is copied from the device
// to the host and the host blocks on the stream until the copy finishes.
//
// Output 0 has the shape of `input_data` with dimension 0 replaced by the
// number of output rows. The reduction itself is delegated to
// functor::SparseSegmentSumFunctor.
//
// `done` is handed to the OP_REQUIRES*_ASYNC macros on every failure path and
// is otherwise scheduled on the stream through the device's event manager
// after the functor has been invoked.
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
  const Tensor& input_data = context->input(0);
  const Tensor& indices = context->input(1);
  const Tensor& segment_ids = context->input(2);

  // Validate every input *before* enqueuing any asynchronous copy, so that no
  // early return can happen while a copy into `output_rows_host` (a local
  // object) may still be pending on the stream.
  OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsVector(indices.shape()),
                    errors::InvalidArgument("indices should be a vector."),
                    done);
  const int64 num_indices = indices.NumElements();
  OP_REQUIRES_ASYNC(
      context, TensorShapeUtils::IsVector(segment_ids.shape()),
      errors::InvalidArgument("segment_ids should be a vector."), done);
  OP_REQUIRES_ASYNC(
      context, num_indices == segment_ids.NumElements(),
      errors::InvalidArgument("indices and segment_ids should have "
                              "same length."),
      done);

  const Tindex input_dims = input_data.dims();
  OP_REQUIRES_ASYNC(
      context, input_dims >= 1,
      errors::InvalidArgument("input must be at least rank 1."), done);

  // Number of elements in one row of `input_data`, i.e. the product of all
  // dimensions except the first.
  Tindex element_size = 1;
  const TensorShape& input_shape = input_data.shape();
  for (Tindex i = 1; i < input_dims; i++) {
    element_size *= input_shape.dim_size(i);
  }

  ScratchSpace<Tindex> output_rows_host(context, /*size=*/1,
                                        /*on_host=*/true);
  auto stream = context->op_device_context()->stream();
  OP_REQUIRES_ASYNC(context, stream != nullptr,
                    errors::Internal("No GPU stream available."), done);

  if (has_num_segments) {
    // NOTE(review): assumes num_segments is a host-resident, single-element
    // tensor of type Tindex -- confirm against the kernel registration.
    const Tensor& num_segments = context->input(3);
    output_rows_host.tensor().CopyFrom(num_segments, num_segments.shape());
  } else {
    // The row count is derived from the last segment id, so there must be at
    // least one; otherwise the pointer below would refer to memory before the
    // start of the segment_ids buffer.
    OP_REQUIRES_ASYNC(
        context, num_indices > 0,
        errors::InvalidArgument("segment_ids should not be empty when "
                                "num_segments is not specified."),
        done);
    se::DeviceMemoryBase last_segment_id_on_device(const_cast<Tindex*>(
        segment_ids.template flat<Tindex>().data() + num_indices - 1));
    OP_REQUIRES_ASYNC(
        context,
        stream
            ->ThenMemcpy(output_rows_host.mutable_data(),
                         last_segment_id_on_device, sizeof(Tindex))
            .ok(),
        errors::Internal(
            "SparseSegmentSumGpuOp: failed to copy output_rows to host."),
        done);
  }

  // Wait for the stream so the host-side value is ready before it is read.
  OP_REQUIRES_OK_ASYNC(context, stream->BlockHostUntilDone(), done);
  Tindex output_rows = *output_rows_host.data();
  // Since segment_ids counts from 0 for output position, the output_rows
  // is increased by 1, if there is no specified num_segments value.
  if (!has_num_segments) {
    output_rows++;
  }
  OP_REQUIRES_ASYNC(context, output_rows > 0,
                    errors::InvalidArgument("Segment ids must be >= 0"),
                    done);

  TensorShape output_shape = input_data.shape();
  output_shape.set_dim(0, output_rows);
  Tensor* output = nullptr;
  OP_REQUIRES_OK_ASYNC(
      context, context->allocate_output(0, output_shape, &output), done);

  functor::SparseSegmentSumFunctor<T, Tindex> executant(
      output_rows, num_indices, num_indices * element_size, input_data,
      indices, segment_ids, output);
  ScopedActivateExecutorContext scoped_activation{stream->parent()};
  executant(context, context->eigen_device<GPUDevice>());

  // Schedule `done` on the stream via the event manager rather than calling
  // it directly here.
  context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
      stream, done);
}