in tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc [439:593]
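// Performs a single time step of a fully-quantized LSTM ("8x8_16": 8-bit
// activations and weights, 16-bit gate and cell-state arithmetic). Each
// effective_*_scale_a / effective_*_scale_b pair encodes a rescaling factor
// in TFLite's usual fixed-point convention: scale_a is a Q31 integer
// multiplier and scale_b the accompanying power-of-two shift. The
// *_variance_guard values guard the variance computation inside each gate's
// layer normalization.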
inline void LstmStepInteger8x8_16(
const int8_t* input_ptr, const int8_t* input_to_input_weight_ptr,
int32_t effective_input_to_input_scale_a,
int32_t effective_input_to_input_scale_b,
const int8_t* input_to_forget_weight_ptr,
int32_t effective_input_to_forget_scale_a,
int32_t effective_input_to_forget_scale_b,
const int8_t* input_to_cell_weight_ptr,
int32_t effective_input_to_cell_scale_a,
int32_t effective_input_to_cell_scale_b,
const int8_t* input_to_output_weight_ptr,
int32_t effective_input_to_output_scale_a,
int32_t effective_input_to_output_scale_b,
const int8_t* recurrent_to_input_weight_ptr,
int32_t effective_recurrent_to_input_scale_a,
int32_t effective_recurrent_to_input_scale_b,
const int8_t* recurrent_to_forget_weight_ptr,
int32_t effective_recurrent_to_forget_scale_a,
int32_t effective_recurrent_to_forget_scale_b,
const int8_t* recurrent_to_cell_weight_ptr,
int32_t effective_recurrent_to_cell_scale_a,
int32_t effective_recurrent_to_cell_scale_b,
const int8_t* recurrent_to_output_weight_ptr,
int32_t effective_recurrent_to_output_scale_a,
int32_t effective_recurrent_to_output_scale_b,
const int16_t* cell_to_input_weight_ptr,
int32_t effective_cell_to_input_scale_a,
int32_t effective_cell_to_input_scale_b,
const int16_t* cell_to_forget_weight_ptr,
int32_t effective_cell_to_forget_scale_a,
int32_t effective_cell_to_forget_scale_b,
const int16_t* cell_to_output_weight_ptr,
int32_t effective_cell_to_output_scale_a,
int32_t effective_cell_to_output_scale_b,
const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
int32_t effective_proj_scale_b, int32_t hidden_zp,
int32_t effective_hidden_scale_a, int32_t effective_hidden_scale_b,
const int16_t* layer_norm_input_weight_ptr,
int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
const int16_t* layer_norm_forget_weight_ptr,
int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
int32_t layer_norm_cell_scale_b,
const int16_t* layer_norm_output_weight_ptr,
int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
int16_t quantized_cell_clip, int8_t quantized_proj_clip,
int32_t cell_state_scale, int32_t input_variance_guard,
int32_t forget_variance_guard, int32_t cell_variance_guard,
int32_t output_variance_guard,
const int32_t* input_to_forget_effective_bias,
const int32_t* recurrent_to_forget_effective_bias,
const int32_t* input_to_cell_effective_bias,
const int32_t* recurrent_to_cell_effective_bias,
const int32_t* input_to_output_effective_bias,
const int32_t* recurrent_to_output_effective_bias,
const int32_t* input_to_input_effective_bias,
const int32_t* recurrent_to_input_effective_bias,
const int32_t* projection_effective_bias, int n_batch, int n_cell,
int n_input, int n_output, int8_t* output_state_ptr,
int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
int16_t* scratch0, int16_t* scratch1, int16_t* scratch2, int16_t* scratch3,
int8_t* scratch4, int32_t* scratch5) {
// ruy::profiler::ScopeLabel label("LstmStepInteger8x8_16");
// Make named scratch buffers for the different gates.
int16_t* input_gate_scratch = scratch0;
int16_t* forget_gate_scratch = scratch1;
int16_t* cell_gate_scratch = scratch2;
int16_t* output_gate_scratch = scratch3;
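  // scratch4 (int8) and scratch5 (int32) are passed through to the gate and
  // output calculations below as additional working buffers.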
  // Since we have already checked that the weights are either all present or
  // all absent, checking a single pointer is enough to get the condition.
const bool use_cifg = (input_to_input_weight_ptr == nullptr);
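  // With CIFG (Coupled Input and Forget Gate), no separate input gate is
  // computed; UpdateLstmCellInteger derives it from the forget gate instead.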
// Check for nullptrs.
TFLITE_DCHECK(input_to_forget_effective_bias);
TFLITE_DCHECK(recurrent_to_forget_effective_bias);
TFLITE_DCHECK(input_to_cell_effective_bias);
TFLITE_DCHECK(recurrent_to_cell_effective_bias);
TFLITE_DCHECK(input_to_output_effective_bias);
TFLITE_DCHECK(recurrent_to_output_effective_bias);
if (!use_cifg) {
TFLITE_DCHECK(input_to_input_effective_bias);
TFLITE_DCHECK(recurrent_to_input_effective_bias);
}
const bool use_projection = (projection_weight_ptr != nullptr);
if (use_projection) {
TFLITE_DCHECK(projection_effective_bias);
}
if (!use_cifg) {
// Calculate the input gate. (If not CIFG.)
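    // i_t = sigmoid(W_ii x_t + W_hi h_{t-1} + w_ci . c_{t-1} + b_i), where
    // "." is elementwise and the peephole term is optional; layer
    // normalization, when present, is applied before the activation.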
CalculateLstmGateInteger8x8_16(
input_ptr, input_to_input_weight_ptr, input_to_input_effective_bias,
effective_input_to_input_scale_a, effective_input_to_input_scale_b,
output_state_ptr, recurrent_to_input_weight_ptr,
recurrent_to_input_effective_bias, effective_recurrent_to_input_scale_a,
effective_recurrent_to_input_scale_b, cell_state_ptr,
cell_to_input_weight_ptr, effective_cell_to_input_scale_a,
effective_cell_to_input_scale_b, layer_norm_input_weight_ptr,
input_gate_bias_ptr, layer_norm_input_scale_a, layer_norm_input_scale_b,
input_variance_guard, n_batch, n_input, n_output, n_cell,
kTfLiteActSigmoid, input_gate_scratch, scratch5);
}
// Calculate the forget gate.
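  // f_t = sigmoid(W_if x_t + W_hf h_{t-1} + w_cf . c_{t-1} + b_f), again with
  // an optional peephole term and optional layer normalization.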
CalculateLstmGateInteger8x8_16(
input_ptr, input_to_forget_weight_ptr, input_to_forget_effective_bias,
effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
output_state_ptr, recurrent_to_forget_weight_ptr,
recurrent_to_forget_effective_bias, effective_recurrent_to_forget_scale_a,
effective_recurrent_to_forget_scale_b, cell_state_ptr,
cell_to_forget_weight_ptr, effective_cell_to_forget_scale_a,
effective_cell_to_forget_scale_b, layer_norm_forget_weight_ptr,
forget_gate_bias_ptr, layer_norm_forget_scale_a,
layer_norm_forget_scale_b, forget_variance_guard, n_batch, n_input,
n_output, n_cell, kTfLiteActSigmoid, forget_gate_scratch, scratch5);
// Calculate the cell update gate.
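  // g_t = tanh(W_ig x_t + W_hg h_{t-1} + b_g); the cell update gate has no
  // peephole term, hence the nullptr below.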
CalculateLstmGateInteger8x8_16(
input_ptr, input_to_cell_weight_ptr, input_to_cell_effective_bias,
effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
output_state_ptr, recurrent_to_cell_weight_ptr,
recurrent_to_cell_effective_bias, effective_recurrent_to_cell_scale_a,
effective_recurrent_to_cell_scale_b, cell_state_ptr,
/*cell_to_gate_weights=*/nullptr, /*cell_to_gate_scale_a=*/0,
/*cell_to_gate_scale_b=*/0, layer_norm_cell_weight_ptr,
cell_gate_bias_ptr, layer_norm_cell_scale_a, layer_norm_cell_scale_b,
cell_variance_guard, n_batch, n_input, n_output, n_cell, kTfLiteActTanh,
cell_gate_scratch, scratch5);
// Update the cell state.
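  // c_t = f_t . c_{t-1} + i_t . g_t, optionally clipped to
  // +/-quantized_cell_clip.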
UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr, cell_state_scale,
input_gate_scratch, forget_gate_scratch,
cell_gate_scratch, use_cifg, quantized_cell_clip);
// Calculate the output gate.
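  // o_t = sigmoid(W_io x_t + W_ho h_{t-1} + w_co . c_t + b_o); the optional
  // peephole term uses the freshly updated cell state.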
CalculateLstmGateInteger8x8_16(
input_ptr, input_to_output_weight_ptr, input_to_output_effective_bias,
effective_input_to_output_scale_a, effective_input_to_output_scale_b,
output_state_ptr, recurrent_to_output_weight_ptr,
recurrent_to_output_effective_bias, effective_recurrent_to_output_scale_a,
effective_recurrent_to_output_scale_b, cell_state_ptr,
cell_to_output_weight_ptr, effective_cell_to_output_scale_a,
effective_cell_to_output_scale_b, layer_norm_output_weight_ptr,
output_gate_bias_ptr, layer_norm_output_scale_a,
layer_norm_output_scale_b, output_variance_guard, n_batch, n_input,
n_output, n_cell, kTfLiteActSigmoid, output_gate_scratch, scratch5);
// Update the output state.
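  // h_t = o_t . tanh(c_t), followed, when projection weights are present, by
  // a projection matmul whose result is clipped to +/-quantized_proj_clip.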
CalculateLstmOutputInteger8x8_16(
n_batch, n_cell, n_output, cell_state_ptr, cell_state_scale,
output_gate_scratch, effective_hidden_scale_a, effective_hidden_scale_b,
hidden_zp, projection_weight_ptr, effective_proj_scale_a,
effective_proj_scale_b, projection_effective_bias, output_state_zp,
quantized_proj_clip, output_state_ptr, scratch0, scratch4, scratch5);
  // Copy the output state to the output. Note that unlike the float or hybrid
  // versions, the integer output is always contiguous.
std::copy_n(output_state_ptr, n_batch * n_output, output_ptr);
}
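
// ---------------------------------------------------------------------------
// Illustrative sketch only (not part of the kernel, and the helper name is
// hypothetical): how an effective_*_scale_a / effective_*_scale_b pair above
// is typically applied to a raw accumulator. scale_a is a Q31 fixed-point
// multiplier and scale_b a power-of-two shift, so the rescaling computes
// roughly acc * scale_a * 2^scale_b / 2^31 with rounding. TFLite's real
// helper, MultiplyByQuantizedMultiplier() in
// tensorflow/lite/kernels/internal/common.h, additionally saturates; the
// simplified version below omits saturation for clarity.
#include <cstdint>  // int32_t/int64_t, if this sketch is lifted elsewhere

inline int32_t RescaleSketch(int32_t acc, int32_t scale_a, int32_t scale_b) {
  // Apply the Q31 multiplier with round-to-nearest.
  const int64_t product = static_cast<int64_t>(acc) * scale_a;
  const int64_t rescaled = (product + (int64_t{1} << 30)) >> 31;
  if (scale_b >= 0) {
    return static_cast<int32_t>(rescaled << scale_b);
  }
  // Negative shift: rounding divide by 2^(-scale_b).
  const int64_t half = int64_t{1} << (-scale_b - 1);
  return static_cast<int32_t>((rescaled + half) >> -scale_b);
}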