void Compute()

in tensorflow_text/core/kernels/tokenizer_from_logits_kernel.cc [128:223]


  // Tokenizes a batch of strings using per-character logits and emits the
  // components of a ragged tensor: token strings ("output_values"), row
  // splits ("row_splits"), and begin/end byte offsets ("start_values" /
  // "limit_values"). The actual per-string split decision is delegated to
  // TokenizeByLogits; logits[i, j, :] presumably scores a split vs. no-split
  // at character j of string i (dimension #1 is checked inside
  // TokenizeByLogits, dimension #2 must be 2).
  void Compute(OpKernelContext* ctx) override {
    const Tensor* strings;
    OP_REQUIRES_OK(ctx, ctx->input("strings", &strings));
    const Tensor* logits;
    OP_REQUIRES_OK(ctx, ctx->input("logits", &logits));
    // Validate ranks before touching dim_size() / tensor<float, 3>() /
    // scalar<bool>(): each of those CHECK-fails on a rank mismatch, which
    // would crash the process instead of returning a catchable error.
    OP_REQUIRES(ctx, strings->dims() == 1,
                errors::InvalidArgument("Expecting strings to have rank 1, got ",
                                        strings->dims()));
    OP_REQUIRES(ctx, logits->dims() == 3,
                errors::InvalidArgument("Expecting logits to have rank 3, got ",
                                        logits->dims()));
    OP_REQUIRES(ctx, strings->dim_size(0) == logits->dim_size(0),
                errors::InvalidArgument("Expecting logits to have ",
                                        strings->dim_size(0),
                                        " rows, got ",
                                        logits->dim_size(0)));
    const Tensor* force_split_at_break_character;
    OP_REQUIRES_OK(ctx, ctx->input("force_split_at_break_character",
                                   &force_split_at_break_character));
    OP_REQUIRES(ctx, force_split_at_break_character->dims() == 0,
                errors::InvalidArgument(
                    "Expecting force_split_at_break_character to be a scalar, "
                    "got rank ",
                    force_split_at_break_character->dims()));
    const bool force_split_at_break_character_bool =
        force_split_at_break_character->scalar<bool>()();

    // Flat accumulators across all input strings; row boundaries are kept in
    // output_row_splits (standard ragged-tensor encoding, starts at 0).
    std::vector<string> tokens;
    std::vector<int> begin_offset;
    std::vector<int> end_offset;
    std::vector<int> output_row_splits(1, 0);

    // Tensor to access values from logits.
    const TTypes<const float, 3>::Tensor logits_tensor =
        logits->tensor<float, 3>();

    // Iterate through all the values and tokenize them.
    const auto& strings_vec = strings->flat<tstring>();
    OP_REQUIRES(ctx, logits_tensor.dimension(0) >= strings_vec.size(),
                errors::Internal("Bad logits dimension #0: ",
                                 logits_tensor.dimension(0), " < ",
                                 strings_vec.size()));
    // Dimension #1 of logits will be checked inside TokenizeByLogits.
    OP_REQUIRES(ctx, logits_tensor.dimension(2) == 2,
                errors::Internal("Bad logits dimension #2: ",
                                 logits_tensor.dimension(2), " != 2"));
    for (int i = 0; i < strings_vec.size(); ++i) {
      // Tokenize into tokens and record the offset locations.
      int num_tokens = 0;
      OP_REQUIRES_OK(
          ctx, TokenizeByLogits(
                   strings_vec(i),
                   logits_tensor, i,
                   force_split_at_break_character_bool,
                   &tokens, &begin_offset, &end_offset, &num_tokens));

      // Record the row splits (cumulative token count per string).
      output_row_splits.push_back(num_tokens + output_row_splits.back());
    }

    // All token-aligned outputs (values, start, limit) share one shape.
    std::vector<int64> output_tokens_shape;
    output_tokens_shape.push_back(tokens.size());

    std::vector<int64> output_row_splits_shape;
    output_row_splits_shape.push_back(output_row_splits.size());

    Tensor* output_values;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("output_values",
                                             TensorShape(output_tokens_shape),
                                             &output_values));
    auto output_values_vec = output_values->vec<tstring>();

    Tensor* output_row_splits_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("row_splits",
                                        TensorShape(output_row_splits_shape),
                                        &output_row_splits_tensor));
    auto output_row_splits_vec = output_row_splits_tensor->vec<int64>();

    Tensor* start_values;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("start_values",
                                             TensorShape(output_tokens_shape),
                                             &start_values));
    auto start_values_vec = start_values->vec<int64>();

    Tensor* limit_values;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("limit_values",
                                             TensorShape(output_tokens_shape),
                                             &limit_values));
    auto limit_values_vec = limit_values->vec<int64>();

    // Copy the accumulated results into the allocated output tensors.
    for (int i = 0; i < tokens.size(); ++i) {
      output_values_vec(i) = tokens[i];
    }

    for (int i = 0; i < output_row_splits.size(); ++i) {
      output_row_splits_vec(i) = output_row_splits[i];
    }

    for (int i = 0; i < begin_offset.size(); ++i) {
      start_values_vec(i) = begin_offset[i];
    }

    for (int i = 0; i < end_offset.size(); ++i) {
      limit_values_vec(i) = end_offset[i];
    }
  }