in tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc [807:1055]
// Evaluates the fully-integer "8x8_16" LSTM over every time step of `input`
// by calling LstmStepInteger8x8_16 once per step. Input, weights, projection
// weights and output state are read as int8; cell state, peephole weights and
// layer-norm coefficients as int16; gate biases as int32.
//
// Input layout (only checked with TFLITE_DCHECK):
//   * rank 2:              a single time step, dims = [n_batch, n_input].
//   * rank 3, time_major:  dims = [max_time, n_batch, n_input].
//   * rank 3, !time_major: dims = [n_batch, max_time, n_input].
// n_cell is taken from input_to_output_weights->dims[0] and n_output from
// recurrent_to_output_weights->dims[1].
//
// When `forward_sequence` is false, time steps are visited from max_time - 1
// down to 0. This applies to both the time-major and batch-major layouts.
//
// `context`, `node`, `params` and `projection_bias` are not read by this
// function body (the projection bias enters via
// integer_lstm_param->projection_effective_bias).
//
// NOTE(review): scratch0..scratch5 are reinterpreted directly as raw
// int16/int8/int32 buffers rather than read through GetTensorData. The caller
// therefore presumably passes scratch-arena buffers typed as
// TfLiteEvalTensor*; confirm against the caller. If real TfLiteEvalTensor
// objects were ever passed, these casts would overwrite the tensor structs.
//
// Always returns kTfLiteOk.
TfLiteStatus EvalInteger8x8_16(
    TfLiteContext* context, TfLiteNode* node, const TfLiteEvalTensor* input,
    const TfLiteEvalTensor* input_to_input_weights,
    const TfLiteEvalTensor* input_to_forget_weights,
    const TfLiteEvalTensor* input_to_cell_weights,
    const TfLiteEvalTensor* input_to_output_weights,
    const TfLiteEvalTensor* recurrent_to_input_weights,
    const TfLiteEvalTensor* recurrent_to_forget_weights,
    const TfLiteEvalTensor* recurrent_to_cell_weights,
    const TfLiteEvalTensor* recurrent_to_output_weights,
    const TfLiteEvalTensor* cell_to_input_weights,
    const TfLiteEvalTensor* cell_to_forget_weights,
    const TfLiteEvalTensor* cell_to_output_weights,
    const TfLiteEvalTensor* input_layer_norm_coefficients,
    const TfLiteEvalTensor* forget_layer_norm_coefficients,
    const TfLiteEvalTensor* cell_layer_norm_coefficients,
    const TfLiteEvalTensor* output_layer_norm_coefficients,
    const TfLiteEvalTensor* input_gate_bias,
    const TfLiteEvalTensor* forget_gate_bias,
    const TfLiteEvalTensor* cell_gate_bias,
    const TfLiteEvalTensor* output_gate_bias,
    const TfLiteEvalTensor* projection_weights,
    const TfLiteEvalTensor* projection_bias, const TfLiteLSTMParams* params,
    bool forward_sequence, bool time_major,
    const lstm_eval::IntegerLstmParameter* integer_lstm_param,
    TfLiteEvalTensor* output_state, TfLiteEvalTensor* cell_state,
    TfLiteEvalTensor* output, TfLiteEvalTensor* scratch0,
    TfLiteEvalTensor* scratch1, TfLiteEvalTensor* scratch2,
    TfLiteEvalTensor* scratch3, TfLiteEvalTensor* scratch4,
    TfLiteEvalTensor* scratch5) {
  TFLITE_DCHECK(input->dims->size >= 2 && input->dims->size <= 3);
  const int n_input = input->dims->data[input->dims->size - 1];
  int max_time, n_batch;
  if (input->dims->size == 2) {
    max_time = 1;
    n_batch = input->dims->data[0];
  } else {
    max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
    n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
  }
  // n_cell and n_output will be the same size when there is no projection.
  const int n_cell = input_to_output_weights->dims->data[0];
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Output-state (activation) zero point, hard-coded to 0. The intended
  // source, output_state->params.zero_point, is not used here (presumably
  // because TfLiteEvalTensor carries no quantization params — confirm).
  // TODO: is data.output_zero_point equal to output_state->params.zero_point?
  const int output_state_zp = 0;

  // Length of one batch row in `output`; also used as the row stride of
  // `output_state` in the batch-major branch.
  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];

  // Base data pointers; every per-step pointer is an offset from these.
  const int8_t* const input_data = tflite::micro::GetTensorData<int8_t>(input);
  int8_t* const output_data = tflite::micro::GetTensorData<int8_t>(output);
  int8_t* const output_state_data =
      tflite::micro::GetTensorData<int8_t>(output_state);
  int16_t* const cell_state_data =
      tflite::micro::GetTensorData<int16_t>(cell_state);

  // Runs one LSTM step over `step_batch` batch rows. Weights, biases,
  // quantization parameters and scratch buffers are identical for every call;
  // only the input/output slices and the recurrent-state pointers vary.
  // Keeping the long argument list in one place prevents the two layout
  // branches below from drifting apart.
  auto run_step = [&](const int8_t* input_ptr, int step_batch,
                      int8_t* output_state_ptr, int16_t* cell_state_ptr,
                      int8_t* output_ptr) {
    LstmStepInteger8x8_16(
        input_ptr,
        tflite::micro::GetTensorData<int8_t>(input_to_input_weights),
        integer_lstm_param->effective_input_to_input_scale_a,
        integer_lstm_param->effective_input_to_input_scale_b,
        tflite::micro::GetTensorData<int8_t>(input_to_forget_weights),
        integer_lstm_param->effective_input_to_forget_scale_a,
        integer_lstm_param->effective_input_to_forget_scale_b,
        tflite::micro::GetTensorData<int8_t>(input_to_cell_weights),
        integer_lstm_param->effective_input_to_cell_scale_a,
        integer_lstm_param->effective_input_to_cell_scale_b,
        tflite::micro::GetTensorData<int8_t>(input_to_output_weights),
        integer_lstm_param->effective_input_to_output_scale_a,
        integer_lstm_param->effective_input_to_output_scale_b,
        tflite::micro::GetTensorData<int8_t>(recurrent_to_input_weights),
        integer_lstm_param->effective_recurrent_to_input_scale_a,
        integer_lstm_param->effective_recurrent_to_input_scale_b,
        tflite::micro::GetTensorData<int8_t>(recurrent_to_forget_weights),
        integer_lstm_param->effective_recurrent_to_forget_scale_a,
        integer_lstm_param->effective_recurrent_to_forget_scale_b,
        tflite::micro::GetTensorData<int8_t>(recurrent_to_cell_weights),
        integer_lstm_param->effective_recurrent_to_cell_scale_a,
        integer_lstm_param->effective_recurrent_to_cell_scale_b,
        tflite::micro::GetTensorData<int8_t>(recurrent_to_output_weights),
        integer_lstm_param->effective_recurrent_to_output_scale_a,
        integer_lstm_param->effective_recurrent_to_output_scale_b,
        tflite::micro::GetTensorData<int16_t>(cell_to_input_weights),
        integer_lstm_param->effective_cell_to_input_scale_a,
        integer_lstm_param->effective_cell_to_input_scale_b,
        tflite::micro::GetTensorData<int16_t>(cell_to_forget_weights),
        integer_lstm_param->effective_cell_to_forget_scale_a,
        integer_lstm_param->effective_cell_to_forget_scale_b,
        tflite::micro::GetTensorData<int16_t>(cell_to_output_weights),
        integer_lstm_param->effective_cell_to_output_scale_a,
        integer_lstm_param->effective_cell_to_output_scale_b,
        tflite::micro::GetTensorData<int8_t>(projection_weights),
        integer_lstm_param->effective_proj_scale_a,
        integer_lstm_param->effective_proj_scale_b,
        integer_lstm_param->hidden_zp,
        integer_lstm_param->effective_hidden_scale_a,
        integer_lstm_param->effective_hidden_scale_b,
        tflite::micro::GetTensorData<int16_t>(input_layer_norm_coefficients),
        integer_lstm_param->layer_norm_input_scale_a,
        integer_lstm_param->layer_norm_input_scale_b,
        tflite::micro::GetTensorData<int16_t>(forget_layer_norm_coefficients),
        integer_lstm_param->layer_norm_forget_scale_a,
        integer_lstm_param->layer_norm_forget_scale_b,
        tflite::micro::GetTensorData<int16_t>(cell_layer_norm_coefficients),
        integer_lstm_param->layer_norm_cell_scale_a,
        integer_lstm_param->layer_norm_cell_scale_b,
        tflite::micro::GetTensorData<int16_t>(output_layer_norm_coefficients),
        integer_lstm_param->layer_norm_output_scale_a,
        integer_lstm_param->layer_norm_output_scale_b,
        tflite::micro::GetTensorData<int32_t>(input_gate_bias),
        tflite::micro::GetTensorData<int32_t>(forget_gate_bias),
        tflite::micro::GetTensorData<int32_t>(cell_gate_bias),
        tflite::micro::GetTensorData<int32_t>(output_gate_bias),
        integer_lstm_param->quantized_cell_clip,
        integer_lstm_param->quantized_proj_clip,
        integer_lstm_param->cell_scale,
        integer_lstm_param->input_variance_guard,
        integer_lstm_param->forget_variance_guard,
        integer_lstm_param->cell_variance_guard,
        integer_lstm_param->output_variance_guard,
        integer_lstm_param->input_to_forget_effective_bias.get(),
        integer_lstm_param->recurrent_to_forget_effective_bias.get(),
        integer_lstm_param->input_to_cell_effective_bias.get(),
        integer_lstm_param->recurrent_to_cell_effective_bias.get(),
        integer_lstm_param->input_to_output_effective_bias.get(),
        integer_lstm_param->recurrent_to_output_effective_bias.get(),
        integer_lstm_param->input_to_input_effective_bias.get(),
        integer_lstm_param->recurrent_to_input_effective_bias.get(),
        integer_lstm_param->projection_effective_bias.get(), step_batch,
        n_cell, n_input, n_output, output_state_ptr, output_state_zp,
        cell_state_ptr, output_ptr,
        // Scratch tensors are used as raw buffers (see function comment).
        reinterpret_cast<int16_t*>(scratch0),
        reinterpret_cast<int16_t*>(scratch1),
        reinterpret_cast<int16_t*>(scratch2),
        reinterpret_cast<int16_t*>(scratch3),
        reinterpret_cast<int8_t*>(scratch4),
        reinterpret_cast<int32_t*>(scratch5));
  };

  if (time_major) {
    // One call per time step covers all n_batch rows; the state buffers are
    // passed whole.
    const int input_step = n_batch * n_input;
    const int output_step = n_batch * output_batch_leading_dim;
    for (int t = 0; t < max_time; t++) {
      // Honor forward_sequence here too, matching the batch-major branch
      // (previously this branch always stepped forward).
      const int t_rel = forward_sequence ? t : max_time - t - 1;
      run_step(input_data + t_rel * input_step, n_batch, output_state_data,
               cell_state_data, output_data + t_rel * output_step);
    }
  } else {
    // One call per (batch row, time step) with a batch size of 1, pointing the
    // recurrent state at row b.
    const int input_step = n_input;
    const int output_step = output_batch_leading_dim;
    for (int b = 0; b < n_batch; b++) {
      // Offset the {output,cell}_state pointers to the right batch.
      int8_t* const output_state_ptr =
          output_state_data + b * output_batch_leading_dim;
      int16_t* const cell_state_ptr = cell_state_data + b * n_cell;
      for (int t = 0; t < max_time; t++) {
        // If this is the forward_sequence, step forward, otherwise step
        // backwards.
        const int t_rel = forward_sequence ? t : max_time - t - 1;
        const int time_offset = b * max_time + t_rel;
        run_step(input_data + time_offset * input_step, /*step_batch=*/1,
                 output_state_ptr, cell_state_ptr,
                 output_data + time_offset * output_step);
      }
    }
  }
  return kTfLiteOk;
}