TfLiteStatus Eval()

in tensorflow/lite/micro/kernels/xtensa/unidirectional_sequence_lstm.cc [948:1114]


// Eval for the xtensa UNIDIRECTIONAL_SEQUENCE_LSTM kernel.
//
// Gathers the node's input, state and output eval tensors, copies the LSTM
// parameters out of the sequence builtin data, fetches the six scratch buffers
// whose first index is recorded in OpData, and dispatches to
// lstm_eval::EvalInteger8x8_16.
//
// Supported configuration: int8 `input_to_output_weights` with a non-float32
// input. A float32 input with int8 weights (hybrid) and any other weight type
// are rejected with kTfLiteError.
//
// Every input is fetched through tflite::micro::GetEvalInput /
// GetMutableEvalInput. Those helpers return nullptr for an omitted optional
// input (negative tensor index). Calling context->GetEvalTensor with that
// index directly would index the eval-tensor table out of range.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->builtin_data != nullptr);
  TFLITE_DCHECK(node->user_data != nullptr);
  const auto* params =
      static_cast<const TfLiteUnidirectionalSequenceLSTMParams*>(
          node->builtin_data);
  const OpData* op_data = static_cast<const OpData*>(node->user_data);

  // Input and weight tensors.
  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kInputTensor);
  const TfLiteEvalTensor* input_to_input_weights = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kInputToInputWeightsTensor);
  const TfLiteEvalTensor* input_to_forget_weights = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kInputToForgetWeightsTensor);
  const TfLiteEvalTensor* input_to_cell_weights = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kInputToCellWeightsTensor);
  const TfLiteEvalTensor* input_to_output_weights = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kInputToOutputWeightsTensor);
  const TfLiteEvalTensor* recurrent_to_input_weights =
      tflite::micro::GetEvalInput(
          context, node, micro::lstm::full::kRecurrentToInputWeightsTensor);
  const TfLiteEvalTensor* recurrent_to_forget_weights =
      tflite::micro::GetEvalInput(
          context, node, micro::lstm::full::kRecurrentToForgetWeightsTensor);
  const TfLiteEvalTensor* recurrent_to_cell_weights =
      tflite::micro::GetEvalInput(
          context, node, micro::lstm::full::kRecurrentToCellWeightsTensor);
  const TfLiteEvalTensor* recurrent_to_output_weights =
      tflite::micro::GetEvalInput(
          context, node, micro::lstm::full::kRecurrentToOutputWeightsTensor);

  // Peephole weights, gate biases, projection and layer-norm coefficients.
  // Any of these may be omitted from the model. GetEvalInput yields nullptr
  // for an omitted input rather than dereferencing a negative tensor index.
  const TfLiteEvalTensor* cell_to_input_weights = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kCellToInputWeightsTensor);
  const TfLiteEvalTensor* cell_to_forget_weights = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kCellToForgetWeightsTensor);
  const TfLiteEvalTensor* cell_to_output_weights = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kCellToOutputWeightsTensor);

  const TfLiteEvalTensor* input_gate_bias = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kInputGateBiasTensor);
  const TfLiteEvalTensor* forget_gate_bias = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kForgetGateBiasTensor);
  const TfLiteEvalTensor* cell_gate_bias = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kCellGateBiasTensor);
  const TfLiteEvalTensor* output_gate_bias = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kOutputGateBiasTensor);

  const TfLiteEvalTensor* projection_weights = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kProjectionWeightsTensor);
  const TfLiteEvalTensor* projection_bias = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kProjectionBiasTensor);

  const TfLiteEvalTensor* input_layer_norm_coefficients =
      tflite::micro::GetEvalInput(
          context, node,
          micro::lstm::full::kInputLayerNormCoefficientsTensor);
  const TfLiteEvalTensor* forget_layer_norm_coefficients =
      tflite::micro::GetEvalInput(
          context, node,
          micro::lstm::full::kForgetLayerNormCoefficientsTensor);
  const TfLiteEvalTensor* cell_layer_norm_coefficients =
      tflite::micro::GetEvalInput(
          context, node,
          micro::lstm::full::kCellLayerNormCoefficientsTensor);
  const TfLiteEvalTensor* output_layer_norm_coefficients =
      tflite::micro::GetEvalInput(
          context, node,
          micro::lstm::full::kOutputLayerNormCoefficientsTensor);

  // State tensors are handed to the evaluator as mutable tensors, so fetch
  // them through the mutable accessor.
  TfLiteEvalTensor* output_state = tflite::micro::GetMutableEvalInput(
      context, node, micro::lstm::full::kOutputStateTensor);
  TFLITE_DCHECK(output_state != nullptr);
  TfLiteEvalTensor* cell_state = tflite::micro::GetMutableEvalInput(
      context, node, micro::lstm::full::kCellStateTensor);
  TFLITE_DCHECK(cell_state != nullptr);

  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(
      context, node, micro::lstm::full::kOutputTensor);

  TFLITE_DCHECK(input != nullptr);
  TFLITE_DCHECK(input_to_output_weights != nullptr);

  // Only int8 weights are implemented by this kernel.
  if (input_to_output_weights->type != kTfLiteInt8) {
    TF_LITE_KERNEL_LOG(context, "Type %s is not currently supported.",
                       TfLiteTypeGetName(input_to_output_weights->type));
    return kTfLiteError;
  }
  // int8 weights combined with a float32 input is the hybrid configuration,
  // which is not implemented here.
  if (input->type == kTfLiteFloat32) {
    TF_LITE_KERNEL_LOG(context, " hybrid type is not supported.");
    return kTfLiteError;
  }

  // Copy the LSTM-specific params so they can be passed to the evaluator.
  // Value-initialize first so that any TfLiteLSTMParams field not copied
  // below is zero rather than indeterminate.
  TfLiteLSTMParams lstm_params = {};
  lstm_params.activation = params->activation;
  lstm_params.cell_clip = params->cell_clip;
  lstm_params.proj_clip = params->proj_clip;
  lstm_params.asymmetric_quantize_inputs = params->asymmetric_quantize_inputs;

  // Fetch the six scratch buffers for this node. Their indices are consecutive,
  // starting at op_data->scratch_tensor_index. Their sizes and element types
  // are established outside this function (presumably in Prepare).
  constexpr int kLstmScratchBufferCount = 6;
  TFLITE_DCHECK(context->GetScratchBuffer != nullptr);
  TfLiteEvalTensor* scratch[kLstmScratchBufferCount];
  for (int i = 0; i < kLstmScratchBufferCount; ++i) {
    void* buffer =
        context->GetScratchBuffer(context, op_data->scratch_tensor_index + i);
    if (buffer == nullptr) {
      TF_LITE_KERNEL_LOG(context, "Failed to get LSTM scratch buffer %d.", i);
      return kTfLiteError;
    }
    // NOTE(review): this is raw scratch memory, not a real TfLiteEvalTensor.
    // EvalInteger8x8_16 takes TfLiteEvalTensor* parameters and presumably
    // reinterprets them back to data pointers. Confirm against lstm_eval
    // before changing either side.
    scratch[i] = static_cast<TfLiteEvalTensor*>(buffer);
  }

  return lstm_eval::EvalInteger8x8_16(
      context, node, input, input_to_input_weights, input_to_forget_weights,
      input_to_cell_weights, input_to_output_weights,
      recurrent_to_input_weights, recurrent_to_forget_weights,
      recurrent_to_cell_weights, recurrent_to_output_weights,
      cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
      input_layer_norm_coefficients, forget_layer_norm_coefficients,
      cell_layer_norm_coefficients, output_layer_norm_coefficients,
      input_gate_bias, forget_gate_bias, cell_gate_bias, output_gate_bias,
      projection_weights, projection_bias, &lstm_params,
      /*forward_sequence=*/true, params->time_major,
      &op_data->integer_lstm_param, output_state, cell_state, output,
      scratch[0], scratch[1], scratch[2], scratch[3], scratch[4], scratch[5]);
}