in tensorflow/lite/micro/kernels/xtensa/unidirectional_sequence_lstm.cc [948:1114]
// Runs one invocation of the unidirectional-sequence LSTM kernel.
//
// Only the fully integer-quantized path (int8 weights, invoked via
// lstm_eval::EvalInteger8x8_16) is supported. The hybrid path (float32
// input combined with int8 weights) and every other weight type report an
// error. Returns kTfLiteOk on success, kTfLiteError otherwise.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params =
      reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
          node->builtin_data);
  const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  // Required inputs: fetched through GetEvalInput, which assumes presence.
  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kInputTensor);
  const TfLiteEvalTensor* input_to_input_weights = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kInputToInputWeightsTensor);
  const TfLiteEvalTensor* input_to_forget_weights = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kInputToForgetWeightsTensor);
  const TfLiteEvalTensor* input_to_cell_weights = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kInputToCellWeightsTensor);
  const TfLiteEvalTensor* input_to_output_weights = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kInputToOutputWeightsTensor);
  const TfLiteEvalTensor* recurrent_to_input_weights =
      tflite::micro::GetEvalInput(
          context, node, micro::lstm::full::kRecurrentToInputWeightsTensor);
  const TfLiteEvalTensor* recurrent_to_forget_weights =
      tflite::micro::GetEvalInput(
          context, node, micro::lstm::full::kRecurrentToForgetWeightsTensor);
  const TfLiteEvalTensor* recurrent_to_cell_weights =
      tflite::micro::GetEvalInput(
          context, node, micro::lstm::full::kRecurrentToCellWeightsTensor);
  const TfLiteEvalTensor* recurrent_to_output_weights =
      tflite::micro::GetEvalInput(
          context, node, micro::lstm::full::kRecurrentToOutputWeightsTensor);

  // Optional inputs: fetched directly via GetEvalTensor, which may yield
  // nullptr when the corresponding tensor index is absent (e.g. no
  // peephole, no projection, no layer normalization, or CIFG variants).
  const TfLiteEvalTensor* cell_to_input_weights = context->GetEvalTensor(
      context,
      node->inputs->data[micro::lstm::full::kCellToInputWeightsTensor]);
  const TfLiteEvalTensor* cell_to_forget_weights = context->GetEvalTensor(
      context,
      node->inputs->data[micro::lstm::full::kCellToForgetWeightsTensor]);
  const TfLiteEvalTensor* cell_to_output_weights = context->GetEvalTensor(
      context,
      node->inputs->data[micro::lstm::full::kCellToOutputWeightsTensor]);
  const TfLiteEvalTensor* input_gate_bias = context->GetEvalTensor(
      context, node->inputs->data[micro::lstm::full::kInputGateBiasTensor]);
  const TfLiteEvalTensor* forget_gate_bias = context->GetEvalTensor(
      context, node->inputs->data[micro::lstm::full::kForgetGateBiasTensor]);
  const TfLiteEvalTensor* cell_gate_bias = context->GetEvalTensor(
      context, node->inputs->data[micro::lstm::full::kCellGateBiasTensor]);
  const TfLiteEvalTensor* output_gate_bias = context->GetEvalTensor(
      context, node->inputs->data[micro::lstm::full::kOutputGateBiasTensor]);
  const TfLiteEvalTensor* projection_weights = context->GetEvalTensor(
      context, node->inputs->data[micro::lstm::full::kProjectionWeightsTensor]);
  const TfLiteEvalTensor* projection_bias = context->GetEvalTensor(
      context, node->inputs->data[micro::lstm::full::kProjectionBiasTensor]);

  // State tensors are listed among the node inputs but are mutated across
  // time steps, hence the non-const handles and the mandatory null checks.
  TfLiteEvalTensor* output_state = context->GetEvalTensor(
      context, node->inputs->data[micro::lstm::full::kOutputStateTensor]);
  TFLITE_DCHECK(output_state != nullptr);
  TfLiteEvalTensor* cell_state = context->GetEvalTensor(
      context, node->inputs->data[micro::lstm::full::kCellStateTensor]);
  TFLITE_DCHECK(cell_state != nullptr);

  const TfLiteEvalTensor* input_layer_norm_coefficients =
      context->GetEvalTensor(
          context,
          node->inputs
              ->data[micro::lstm::full::kInputLayerNormCoefficientsTensor]);
  const TfLiteEvalTensor* forget_layer_norm_coefficients =
      context->GetEvalTensor(
          context,
          node->inputs
              ->data[micro::lstm::full::kForgetLayerNormCoefficientsTensor]);
  const TfLiteEvalTensor* cell_layer_norm_coefficients = context->GetEvalTensor(
      context,
      node->inputs->data[micro::lstm::full::kCellLayerNormCoefficientsTensor]);
  const TfLiteEvalTensor* output_layer_norm_coefficients =
      context->GetEvalTensor(
          context,
          node->inputs
              ->data[micro::lstm::full::kOutputLayerNormCoefficientsTensor]);
  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(
      context, node, micro::lstm::full::kOutputTensor);

  // Copy out the LSTM specific params so they can be passed in the function.
  TfLiteLSTMParams lstm_params;
  lstm_params.activation = params->activation;
  lstm_params.cell_clip = params->cell_clip;
  lstm_params.proj_clip = params->proj_clip;
  lstm_params.asymmetric_quantize_inputs = params->asymmetric_quantize_inputs;

  switch (input_to_output_weights->type) {
    case kTfLiteInt8: {
      const bool is_hybrid = input->type == kTfLiteFloat32;
      if (is_hybrid) {
        TF_LITE_KERNEL_LOG(context, " hybrid type is not supported.");
        return kTfLiteError;
      }
      // Fetch the six scratch buffers requested during Prepare; they are
      // stored at consecutive scratch indices starting at
      // op_data->scratch_tensor_index.
      //
      // NOTE(review): GetScratchBuffer returns raw scratch memory, not a
      // TfLiteEvalTensor object; the reinterpret_cast below matches how
      // EvalInteger8x8_16 consumes these pointers on this backend — confirm
      // against the callee's contract before changing.
      TFLITE_DCHECK(context != nullptr);
      TFLITE_DCHECK(context->GetScratchBuffer != nullptr);
      TfLiteEvalTensor* scratch[6];
      for (int scratch_index = 0; scratch_index < 6; ++scratch_index) {
        scratch[scratch_index] =
            reinterpret_cast<TfLiteEvalTensor*>(context->GetScratchBuffer(
                context, op_data->scratch_tensor_index + scratch_index));
      }
      return lstm_eval::EvalInteger8x8_16(
          context, node, input, input_to_input_weights,
          input_to_forget_weights, input_to_cell_weights,
          input_to_output_weights, recurrent_to_input_weights,
          recurrent_to_forget_weights, recurrent_to_cell_weights,
          recurrent_to_output_weights, cell_to_input_weights,
          cell_to_forget_weights, cell_to_output_weights,
          input_layer_norm_coefficients, forget_layer_norm_coefficients,
          cell_layer_norm_coefficients, output_layer_norm_coefficients,
          input_gate_bias, forget_gate_bias, cell_gate_bias, output_gate_bias,
          projection_weights, projection_bias, &lstm_params,
          /*forward_sequence=*/true, params->time_major,
          &op_data->integer_lstm_param, output_state, cell_state, output,
          scratch[0], scratch[1], scratch[2], scratch[3], scratch[4],
          scratch[5]);
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s is not currently supported.",
                         TfLiteTypeGetName(input_to_output_weights->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}