void ComputeAsync()

in tensorflow_recommenders_addons/dynamic_embedding/core/kernels/segment_reduction_ops_impl.h [80:162]


  // Asynchronously computes a sparse segment sum on the GPU.
  //
  // Inputs (by index):
  //   0: input_data  - dense data tensor; must have rank >= 1.
  //   1: indices     - 1-D tensor selecting rows of input_data.
  //   2: segment_ids - 1-D tensor of the same length as indices, mapping
  //                    each selected row to an output segment.
  //   3: num_segments - only read when has_num_segments; the number of
  //                    output rows.
  //
  // Output 0 has input_data's shape with dim 0 replaced by the number of
  // segments. `done` is invoked once all enqueued GPU work has finished.
  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    const Tensor& input_data = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& segment_ids = context->input(2);

    OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsVector(indices.shape()),
                      errors::InvalidArgument("indices should be a vector."),
                      done);
    const int64 num_indices = indices.NumElements();

    OP_REQUIRES_ASYNC(
        context, TensorShapeUtils::IsVector(segment_ids.shape()),
        errors::InvalidArgument("segment_ids should be a vector."), done);

    OP_REQUIRES_ASYNC(
        context, num_indices == segment_ids.NumElements(),
        errors::InvalidArgument("indices and segment_ids should have "
                                "same length."),
        done);

    // Determine the number of output rows: either given explicitly via the
    // num_segments input, or inferred from the last element of segment_ids
    // copied device->host below (this presumably relies on segment_ids being
    // sorted ascending so the last element is the maximum id -- TODO confirm
    // against the op's registered contract).
    ScratchSpace<Tindex> output_rows_host(context, /*size=*/1,
                                          /*on_host=*/true);
    auto stream = context->op_device_context()->stream();

    if (has_num_segments) {
      const Tensor& num_segments = context->input(3);
      output_rows_host.tensor().CopyFrom(num_segments, num_segments.shape());
    } else {
      // Without num_segments we must read segment_ids[num_indices - 1];
      // guard against empty segment_ids, which would otherwise form an
      // out-of-bounds device address for the memcpy below.
      OP_REQUIRES_ASYNC(
          context, num_indices > 0,
          errors::InvalidArgument(
              "segment_ids must be non-empty when num_segments is not "
              "provided."),
          done);
      se::DeviceMemoryBase last_segment_id_on_device(const_cast<Tindex*>(
          segment_ids.template flat<Tindex>().data() + num_indices - 1));
      OP_REQUIRES_ASYNC(
          context,
          stream
              ->ThenMemcpy(output_rows_host.mutable_data(),
                           last_segment_id_on_device, sizeof(Tindex))
              .ok(),
          errors::Internal(
              "SparseSegmentSumGpuOp: failed to copy output_rows to host."),
          done);
    }

    const Tindex input_dims = input_data.dims();
    OP_REQUIRES_ASYNC(
        context, input_dims >= 1,
        errors::InvalidArgument("input_data should have rank >= 1."), done);

    // element_size is the number of scalars per row of input_data, i.e. the
    // product of all trailing (non-leading) dimensions.
    Tindex element_size = 1;
    const TensorShape input_shape = input_data.shape();
    if (input_dims > 1) {
      for (Tindex i = 1; i < input_dims; i++) {
        element_size *= input_shape.dim_size(i);
      }
    }

    // Block until the device->host copy of the row count has completed so
    // output_rows_host can be read safely.
    OP_REQUIRES_OK_ASYNC(context, stream->BlockHostUntilDone(), done);
    Tindex output_rows = *output_rows_host.data();
    // Since segment_ids counts from 0 for output position, the output_rows
    // is increased by 1, if there is no specified num_segments value.
    if (!has_num_segments) {
      output_rows++;
    }
    OP_REQUIRES_ASYNC(context, output_rows > 0,
                      errors::InvalidArgument("Segment ids must be >= 0"),
                      done);
    TensorShape output_shape = input_data.shape();
    output_shape.set_dim(0, output_rows);

    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        context, context->allocate_output(0, output_shape, &output), done);

    functor::SparseSegmentSumFunctor<T, Tindex> executant(
        output_rows, num_indices, num_indices * element_size, input_data,
        indices, segment_ids, output);

    // Ensure the executor's GPU context is active for the functor's kernel
    // launches.
    ScopedActivateExecutorContext scoped_activation{stream->parent()};
    executant(context, context->eigen_device<GPUDevice>());

    // Defer `done` until all work enqueued on the stream has finished.
    context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
        stream, done);
  }