void Compute()

in tensorflow_addons/custom_ops/text/cc/kernels/skip_gram_kernels.cc [40:125]


  // Generates skip-gram (token, label) pairs from a rank-1 input tensor.
  //
  // Named inputs:
  //   input_tensor:        rank-1 tensor of tokens of type T.
  //   min_skips/max_skips: scalar ints; for each token, a skip window size is
  //                        drawn uniformly from [min_skips, max_skips].
  //   start/limit:         scalar ints delimiting the sub-range of the input
  //                        to process; limit < 0 means "through the end".
  //   emit_self_as_target: scalar bool; when true, (token, token) pairs are
  //                        emitted too.
  //
  // Outputs: two equal-length rank-1 tensors, "tokens" and "labels", holding
  // the generated pairs element-wise.
  void Compute(OpKernelContext* context) override {
    const Tensor* input_tensor;
    OP_REQUIRES_OK(context, context->input("input_tensor", &input_tensor));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor->shape()),
                errors::InvalidArgument("input_tensor must be of rank 1"));
    const auto input = input_tensor->flat<T>();

    // Validate shapes *before* calling scalar<>() — on a non-scalar tensor
    // scalar<>() CHECK-fails and aborts the process instead of returning a
    // graceful InvalidArgument error.
    const Tensor* min_skips_tensor;
    OP_REQUIRES_OK(context, context->input("min_skips", &min_skips_tensor));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(min_skips_tensor->shape()),
                errors::InvalidArgument("min_skips must be a scalar"));
    const int min_skips = *(min_skips_tensor->scalar<int>().data());
    const Tensor* max_skips_tensor;
    OP_REQUIRES_OK(context, context->input("max_skips", &max_skips_tensor));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(max_skips_tensor->shape()),
                errors::InvalidArgument("max_skips must be a scalar"));
    const int max_skips = *(max_skips_tensor->scalar<int>().data());

    OP_REQUIRES(
        context, min_skips >= 0 && max_skips >= 0,
        errors::InvalidArgument("Both min_skips and max_skips must be >= 0."));
    OP_REQUIRES(context, min_skips <= max_skips,
                errors::InvalidArgument("min_skips must be <= max_skips."));

    const Tensor* start_tensor;
    OP_REQUIRES_OK(context, context->input("start", &start_tensor));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(start_tensor->shape()),
                errors::InvalidArgument("start must be a scalar"));
    const int start = *(start_tensor->scalar<int>().data());
    const Tensor* limit_tensor;
    OP_REQUIRES_OK(context, context->input("limit", &limit_tensor));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(limit_tensor->shape()),
                errors::InvalidArgument("limit must be a scalar"));
    const int limit = *(limit_tensor->scalar<int>().data());

    const int input_size = static_cast<int>(input.size());
    // An out-of-range start would make the loop below read input(i) out of
    // bounds (start < 0), or hand a negative sample count to the RNG
    // reservation (start > input_size with limit < 0).
    OP_REQUIRES(context, start >= 0 && start <= input_size,
                errors::InvalidArgument(
                    "start must be in [0, input size], but was ", start));
    // Compute end without evaluating start + limit directly, which could
    // overflow int for large limit values.
    const int end = (limit < 0 || limit >= input_size - start)
                        ? input_size
                        : start + limit;

    const Tensor* emit_self_tensor;
    OP_REQUIRES_OK(context,
                   context->input("emit_self_as_target", &emit_self_tensor));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(emit_self_tensor->shape()),
                errors::InvalidArgument("emit_self_as_target must be a scalar"));
    const bool emit_self_as_target = *(emit_self_tensor->scalar<bool>().data());

    std::vector<T> tokens;
    std::vector<T> labels;

    // Reserve the number of random numbers we will use - we use one for each
    // token between start and end.
    random::PhiloxRandom local_gen =
        generator_.ReserveSamples32(end - start + 1);
    random::SimplePhilox rng(&local_gen);

    // For each token in the sentence, pick a random skip, then generates
    // (token, label) pairs for all labels whose distances from the token are
    // within the range [-skip, skip].
    for (int i = start; i < end; ++i) {
      const int skips = min_skips + rng.Uniform(max_skips - min_skips + 1);
      for (int j = -skips; j <= skips; ++j) {
        // Skip labels outside the processed window, and the token itself
        // unless emit_self_as_target is set.
        if ((i + j < start) || (i + j >= end) ||
            (j == 0 && !emit_self_as_target)) {
          continue;
        }
        tokens.push_back(input(i));
        labels.push_back(input(i + j));
      }
    }

    Tensor* tokens_output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       "tokens", TensorShape({static_cast<int>(tokens.size())}),
                       &tokens_output));
    Tensor* labels_output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       "labels", TensorShape({static_cast<int>(labels.size())}),
                       &labels_output));
    // Sanity check: both vectors are appended to in lock-step above, so the
    // outputs must always have identical shapes.
    OP_REQUIRES(
        context, tokens_output->IsSameSize(*labels_output),
        errors::Internal(strings::StrCat(
            "Mismatch between tokens_output shape of ",
            tokens_output->shape().DebugString(),
            " and labels_output shape of ",
            labels_output->shape().DebugString(),
            ". This should never happen - contact ami-team@ if it does.")));

    // Copies results to output tensors.
    for (typename std::vector<T>::size_type i = 0; i < tokens.size(); ++i) {
      tokens_output->vec<T>()(i) = tokens[i];
      labels_output->vec<T>()(i) = labels[i];
    }
  }