in tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc [1057:1212]
// Drives the fully-integer (8-bit activations, 8-bit weights, 8-bit output)
// LSTM over every time step of `input` by calling
// lstm_eval::LstmStepInteger8x8_8 once per step.
//
// Shape handling, as read from `input->dims`:
//   * rank 2: treated as a single time step; dim 0 is the batch size.
//   * otherwise (rank 3 expected): dim 0 is the number of time steps and
//     dim 1 is the batch size.
//   In both cases the last dimension is the input feature size.
//
// The weight, bias and layer-norm tensors are forwarded unchanged to the step
// function together with the matching scale fields of `integer_lstm_param`.
// `output_state`, `cell_state` and the eight scratch tensors are handed to
// every step with the same base pointer; only the input and output pointers
// advance from one step to the next.
//
// TODO: the input and output-state zero points are currently fixed at 0
// instead of being taken from the tensors' quantization parameters.
//
// Always returns kTfLiteOk.
TfLiteStatus EvalInteger8x8_8(
    const TfLiteEvalTensor* input,
    const TfLiteEvalTensor* input_to_input_weights,
    const TfLiteEvalTensor* input_to_forget_weights,
    const TfLiteEvalTensor* input_to_cell_weights,
    const TfLiteEvalTensor* input_to_output_weights,
    const TfLiteEvalTensor* recurrent_to_input_weights,
    const TfLiteEvalTensor* recurrent_to_forget_weights,
    const TfLiteEvalTensor* recurrent_to_cell_weights,
    const TfLiteEvalTensor* recurrent_to_output_weights,
    const TfLiteEvalTensor* cell_to_input_weights,
    const TfLiteEvalTensor* cell_to_forget_weights,
    const TfLiteEvalTensor* cell_to_output_weights,
    const TfLiteEvalTensor* input_layer_norm_coefficients,
    const TfLiteEvalTensor* forget_layer_norm_coefficients,
    const TfLiteEvalTensor* cell_layer_norm_coefficients,
    const TfLiteEvalTensor* output_layer_norm_coefficients,
    const TfLiteEvalTensor* input_gate_bias,
    const TfLiteEvalTensor* forget_gate_bias,
    const TfLiteEvalTensor* cell_gate_bias,
    const TfLiteEvalTensor* output_gate_bias,
    const TfLiteEvalTensor* projection_weights,
    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
    TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
    TfLiteEvalTensor* output,
    const lstm_eval::IntegerLstmParameter* integer_lstm_param,
    TfLiteEvalTensor* scratch0, TfLiteEvalTensor* scratch1,
    TfLiteEvalTensor* scratch2, TfLiteEvalTensor* scratch3,
    TfLiteEvalTensor* scratch4, TfLiteEvalTensor* scratch5,
    TfLiteEvalTensor* scratch6, TfLiteEvalTensor* scratch7) {
  const int input_rank = input->dims->size;
  TFLITE_DCHECK(input_rank >= 2 && input_rank <= 3);

  // A rank-2 input carries no time axis and is processed as one step.
  const bool single_step = (input_rank == 2);
  const int max_time = single_step ? 1 : input->dims->data[0];
  const int n_batch =
      single_step ? input->dims->data[0] : input->dims->data[1];
  const int n_input = input->dims->data[input_rank - 1];

  // Without projection, n_cell and n_output have the same size.
  const int n_cell = input_to_output_weights->dims->data[0];
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // See the TODO in the header comment: both zero points are pinned to 0.
  constexpr int32_t input_zp = 0;
  constexpr int32_t output_state_zp = 0;

  // Number of elements consumed from the input / written to the output per
  // time step.
  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];
  const int input_stride = n_batch * n_input;
  const int output_stride = n_batch * output_batch_leading_dim;

  for (int step = 0; step < max_time; ++step) {
    // Shorthand for the quantization parameters used by this step.
    const auto& quant = *integer_lstm_param;

    // Input data is accessed as int8_t; both pointers select the slice that
    // belongs to the current time step.
    const int8_t* input_ptr =
        tflite::micro::GetTensorData<int8_t>(input) + step * input_stride;
    int8_t* output_ptr =
        tflite::micro::GetTensorData<int8_t>(output) + step * output_stride;

    lstm_eval::LstmStepInteger8x8_8(
        input_ptr, input_zp,
        // Input-to-gate weights and their effective scales.
        tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
        quant.effective_input_to_input_scale_a,
        quant.effective_input_to_input_scale_b,
        tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
        quant.effective_input_to_forget_scale_a,
        quant.effective_input_to_forget_scale_b,
        tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
        quant.effective_input_to_cell_scale_a,
        quant.effective_input_to_cell_scale_b,
        tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
        quant.effective_input_to_output_scale_a,
        quant.effective_input_to_output_scale_b,
        // Recurrent-to-gate weights and their effective scales.
        tflite::micro::GetTensorData<int8_t>(recurrent_to_input_weights),
        quant.effective_recurrent_to_input_scale_a,
        quant.effective_recurrent_to_input_scale_b,
        tflite::micro::GetTensorData<int8_t>(recurrent_to_forget_weights),
        quant.effective_recurrent_to_forget_scale_a,
        quant.effective_recurrent_to_forget_scale_b,
        tflite::micro::GetTensorData<int8_t>(recurrent_to_cell_weights),
        quant.effective_recurrent_to_cell_scale_a,
        quant.effective_recurrent_to_cell_scale_b,
        tflite::micro::GetTensorData<int8_t>(recurrent_to_output_weights),
        quant.effective_recurrent_to_output_scale_a,
        quant.effective_recurrent_to_output_scale_b,
        // Cell-to-gate weights and their effective scales.
        tflite::micro::GetTensorData<int8_t>(cell_to_input_weights),
        quant.effective_cell_to_input_scale_a,
        quant.effective_cell_to_input_scale_b,
        tflite::micro::GetTensorData<int8_t>(cell_to_forget_weights),
        quant.effective_cell_to_forget_scale_a,
        quant.effective_cell_to_forget_scale_b,
        tflite::micro::GetTensorData<int8_t>(cell_to_output_weights),
        quant.effective_cell_to_output_scale_a,
        quant.effective_cell_to_output_scale_b,
        // Projection weights and effective scale.
        tflite::micro::GetTensorData<int8_t>(projection_weights),
        quant.effective_proj_scale_a, quant.effective_proj_scale_b,
        // Layer-normalization coefficients and their scales.
        tflite::micro::GetTensorData<int16_t>(input_layer_norm_coefficients),
        quant.layer_norm_input_scale_a, quant.layer_norm_input_scale_b,
        tflite::micro::GetTensorData<int16_t>(forget_layer_norm_coefficients),
        quant.layer_norm_forget_scale_a, quant.layer_norm_forget_scale_b,
        tflite::micro::GetTensorData<int16_t>(cell_layer_norm_coefficients),
        quant.layer_norm_cell_scale_a, quant.layer_norm_cell_scale_b,
        tflite::micro::GetTensorData<int16_t>(output_layer_norm_coefficients),
        quant.layer_norm_output_scale_a, quant.layer_norm_output_scale_b,
        // Gate and projection biases.
        tflite::micro::GetTensorData<int32_t>(input_gate_bias),
        tflite::micro::GetTensorData<int32_t>(forget_gate_bias),
        tflite::micro::GetTensorData<int32_t>(cell_gate_bias),
        tflite::micro::GetTensorData<int32_t>(output_gate_bias),
        tflite::micro::GetTensorData<int32_t>(projection_bias),
        // Operator parameters, intermediate quantization and clipping.
        params, quant.intermediate_scale_a, quant.intermediate_scale_b,
        quant.intermediate_zp, quant.quantized_cell_clip,
        quant.quantized_proj_clip,
        // Sizes.
        n_batch, n_cell, n_input, n_output, output_batch_leading_dim,
        // State tensors and the output slice for this step.
        tflite::micro::GetTensorData<int8_t>(output_state), output_state_zp,
        tflite::micro::GetTensorData<int16_t>(cell_state), output_ptr,
        // Scratch buffers.
        tflite::micro::GetTensorData<int8_t>(scratch0),
        tflite::micro::GetTensorData<int8_t>(scratch1),
        tflite::micro::GetTensorData<int16_t>(scratch2),
        tflite::micro::GetTensorData<int16_t>(scratch3),
        tflite::micro::GetTensorData<int16_t>(scratch4),
        tflite::micro::GetTensorData<int16_t>(scratch5),
        tflite::micro::GetTensorData<int16_t>(scratch6),
        tflite::micro::GetTensorData<int16_t>(scratch7));
  }
  return kTfLiteOk;
}