TfLiteStatus EvalT()

in tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc [59:132]


TfLiteStatus EvalT(TfLiteContext* context, TfLiteNode* node) {
  // Evaluates the RaggedRange op. T is the element type of starts/limits/
  // deltas and of the dense-values output; SPLITS_TYPE is the element type of
  // the row-splits output (both are template parameters declared outside this
  // function body).
  //
  // For each row i this emits start[i], start[i] + delta[i], ... with
  // RangeSize<T, SPLITS_TYPE>(start[i], limit[i], delta[i]) values. An input
  // holding exactly one element is broadcast to every row; the remaining
  // inputs must be rank-1 and of equal length, and that length is the number
  // of rows (1 when every input is broadcast).
  //
  // Outputs:
  //   kOutputNestedSplits: nrows + 1 row boundaries, starting at 0.
  //   kOutputDenseValues:  the values of all rows, concatenated.
  //
  // Returns kTfLiteError (after reporting through `context`) when a
  // non-broadcast input is not rank 1, when non-broadcast inputs differ in
  // length, when any delta is 0, or when resizing an output fails.
  TfLiteTensor& input_starts =
      context->tensors[node->inputs->data[kInputStarts]];
  TfLiteTensor& input_limits =
      context->tensors[node->inputs->data[kInputLimits]];
  TfLiteTensor& input_deltas =
      context->tensors[node->inputs->data[kInputDeltas]];
  // Determine which tensors we need to broadcast.
  const bool broadcast_starts = NumElements(&input_starts) == 1;
  const bool broadcast_limits = NumElements(&input_limits) == 1;
  const bool broadcast_deltas = NumElements(&input_deltas) == 1;

  // Non-broadcast inputs are read below as flat vectors of length
  // dims->data[0]. For any rank other than 1 that length need not match the
  // element count (e.g. shape [3, 0] holds no elements at all), which would
  // read out of bounds or silently drop elements, so reject such inputs.
  if ((!broadcast_starts && input_starts.dims->size != 1) ||
      (!broadcast_limits && input_limits.dims->size != 1) ||
      (!broadcast_deltas && input_deltas.dims->size != 1)) {
    context->ReportError(
        context,
        "Invalid argument: starts, limits, and deltas must be scalars or "
        "vectors");
    return kTfLiteError;
  }

  // nrows (number of output rows) is the size of the non-broadcast inputs,
  // or 1 if all inputs are scalars.
  std::vector<int> in_sizes;
  if (!broadcast_starts) in_sizes.push_back(input_starts.dims->data[0]);
  if (!broadcast_limits) in_sizes.push_back(input_limits.dims->data[0]);
  if (!broadcast_deltas) in_sizes.push_back(input_deltas.dims->data[0]);
  if (std::adjacent_find(std::begin(in_sizes), std::end(in_sizes),
                         std::not_equal_to<>()) != std::end(in_sizes)) {
    context->ReportError(
        context,
        "Invalid argument: starts, limits, and deltas must have the "
        "same shape");
    return kTfLiteError;
  }

  const SPLITS_TYPE nrows = in_sizes.empty() ? 1 : in_sizes.front();

  const T* starts = GetTensorData<T>(&input_starts);
  const T* limits = GetTensorData<T>(&input_limits);
  const T* deltas = GetTensorData<T>(&input_deltas);

  TfLiteTensor& rt_nested_splits_out =
      context->tensors[node->outputs->data[kOutputNestedSplits]];
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, &rt_nested_splits_out,
                                          IntArrayFromInt(nrows + 1)));
  SPLITS_TYPE* rt_nested_splits =
      GetTensorData<SPLITS_TYPE>(&rt_nested_splits_out);
  rt_nested_splits[0] = 0;

  // Row boundaries are the running sum of the per-row range sizes.
  // NOTE(review): the running sum is not checked for SPLITS_TYPE overflow
  // here; confirm whether RangeSize bounds its result enough to rule it out.
  for (SPLITS_TYPE row = 0; row < nrows; ++row) {
    const T start = broadcast_starts ? starts[0] : starts[row];
    const T limit = broadcast_limits ? limits[0] : limits[row];
    const T delta = broadcast_deltas ? deltas[0] : deltas[row];
    if (delta == 0) {
      context->ReportError(context, "Invalid argument: Requires delta != 0");
      return kTfLiteError;
    }
    rt_nested_splits[row + 1] =
        rt_nested_splits[row] + RangeSize<T, SPLITS_TYPE>(start, limit, delta);
  }
  const SPLITS_TYPE nvals = rt_nested_splits[nrows];

  TfLiteTensor& rt_dense_values_out =
      context->tensors[node->outputs->data[kOutputDenseValues]];
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, &rt_dense_values_out,
                                          IntArrayFromInt(nvals)));
  T* rt_dense_values = GetTensorData<T>(&rt_dense_values_out);

  // Fill each row by repeatedly adding delta to start, writing rows
  // back-to-back into the dense values buffer.
  SPLITS_TYPE value_index = 0;
  for (SPLITS_TYPE row = 0; row < nrows; ++row) {
    const SPLITS_TYPE row_size =
        rt_nested_splits[row + 1] - rt_nested_splits[row];
    T value = broadcast_starts ? starts[0] : starts[row];
    const T delta = broadcast_deltas ? deltas[0] : deltas[row];
    for (SPLITS_TYPE i = 0; i < row_size; ++i) {
      rt_dense_values[value_index++] = value;
      value += delta;
    }
  }
  return kTfLiteOk;
}