in tensorflow/lite/micro/kernels/xtensa/unidirectional_sequence_lstm.cc [415:659]
// Validates the shapes and element types of the LSTM weight, bias, projection
// and layer-norm input tensors of `node` against the expected sizes.
//
// Parameters:
//   context        - kernel context; used by the TF_LITE_ENSURE* macros to
//                    report the failing condition.
//   node           - the UNIDIRECTIONAL_SEQUENCE_LSTM node being validated.
//   n_input        - expected second dimension of the input-to-gate weights.
//   n_output       - expected second dimension of the recurrent weights, first
//                    dimension of the projection weights and size of the
//                    projection bias.
//   n_cell         - expected size of every per-cell vector and first
//                    dimension of the input/recurrent weight matrices.
//   use_layer_norm - when true, the layer-norm coefficient tensors are checked.
//   is_integer     - selects integer expectations (int32 biases, int16
//                    peephole/layer-norm tensors) instead of float32 ones.
//
// Returns kTfLiteOk if every check passes; otherwise returns kTfLiteError from
// the first failing TF_LITE_ENSURE* check.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                        TfLiteNode* node, int n_input,
                                        int n_output, int n_cell,
                                        bool use_layer_norm, bool is_integer) {
  // This op's builtin data is a TfLiteUnidirectionalSequenceLSTMParams. Do not
  // read it through TfLiteLSTMParams: that only "works" because both structs
  // happen to begin with {activation, cell_clip, proj_clip}.
  const auto* params =
      static_cast<const TfLiteUnidirectionalSequenceLSTMParams*>(
          node->builtin_data);
  TF_LITE_ENSURE(context, params != nullptr);

  // Making sure clipping parameters have valid values.
  // == 0 means no clipping
  //  > 0 means clipping
  TF_LITE_ENSURE(context, params->cell_clip >= 0);
  TF_LITE_ENSURE(context, params->proj_clip >= 0);

  // Expected element types, hoisted so each check below reads uniformly.
  const TfLiteType bias_type = is_integer ? kTfLiteInt32 : kTfLiteFloat32;
  const TfLiteType layer_norm_type =
      is_integer ? kTfLiteInt16 : kTfLiteFloat32;

  // NOTE: TF_LITE_ENSURE_EQ logs its operands as integers on failure, so
  // pointer null-checks below use TF_LITE_ENSURE with an explicit comparison.

  // ---- Input-to-gate weights: [n_cell, n_input]. ----
  // The input-gate weights are optional (absent for CIFG-LSTM).
  const TfLiteEvalTensor* input_to_input_weights = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kInputToInputWeightsTensor);
  if (input_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
  }

  const TfLiteEvalTensor* input_to_forget_weights = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kInputToForgetWeightsTensor);
  TF_LITE_ENSURE(context, input_to_forget_weights != nullptr);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);

  const TfLiteEvalTensor* input_to_cell_weights = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kInputToCellWeightsTensor);
  TF_LITE_ENSURE(context, input_to_cell_weights != nullptr);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);

  // ---- Recurrent-to-gate weights: [n_cell, n_output]. ----
  const TfLiteEvalTensor* recurrent_to_input_weights =
      tflite::micro::GetEvalInput(
          context, node, micro::lstm::full::kRecurrentToInputWeightsTensor);
  if (recurrent_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
                      n_cell);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
                      n_output);
  }

  const TfLiteEvalTensor* recurrent_to_forget_weights =
      tflite::micro::GetEvalInput(
          context, node, micro::lstm::full::kRecurrentToForgetWeightsTensor);
  TF_LITE_ENSURE(context, recurrent_to_forget_weights != nullptr);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
                    n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
                    n_output);

  const TfLiteEvalTensor* recurrent_to_cell_weights =
      tflite::micro::GetEvalInput(
          context, node, micro::lstm::full::kRecurrentToCellWeightsTensor);
  TF_LITE_ENSURE(context, recurrent_to_cell_weights != nullptr);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
                    n_output);

  // We make sure the input-gate's parameters are either both present (regular
  // LSTM) or not at all (CIFG-LSTM).
  const bool cifg_weights_all_or_none =
      (input_to_input_weights == nullptr) ==
      (recurrent_to_input_weights == nullptr);
  TF_LITE_ENSURE(context, cifg_weights_all_or_none);

  // ---- Optional peephole weights: [n_cell]. ----
  // Integer models expect int16; otherwise they must match the weight type.
  const TfLiteType peephole_type =
      is_integer ? kTfLiteInt16 : input_to_forget_weights->type;

  const TfLiteEvalTensor* cell_to_input_weights = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kCellToInputWeightsTensor);
  if (cell_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, cell_to_input_weights->type,
                            peephole_type);
  }

  const TfLiteEvalTensor* cell_to_forget_weights = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kCellToForgetWeightsTensor);
  if (cell_to_forget_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, cell_to_forget_weights->type,
                            peephole_type);
  }

  const TfLiteEvalTensor* cell_to_output_weights = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kCellToOutputWeightsTensor);
  if (cell_to_output_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, cell_to_output_weights->type,
                            peephole_type);
  }

  // Making sure the peephole weights are there all or none. With CIFG there is
  // no input gate, so cell_to_input_weights may be absent even when the other
  // two peephole tensors are present.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool peephole_weights_all_or_none =
      ((cell_to_input_weights != nullptr || use_cifg) &&
       (cell_to_forget_weights != nullptr) &&
       (cell_to_output_weights != nullptr)) ||
      ((cell_to_input_weights == nullptr) &&
       (cell_to_forget_weights == nullptr) &&
       (cell_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, peephole_weights_all_or_none);

  // ---- Gate biases: [n_cell]. ----
  const TfLiteEvalTensor* input_gate_bias = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kInputGateBiasTensor);
  if (use_cifg) {
    TF_LITE_ENSURE(context, input_gate_bias == nullptr);
  } else {
    TF_LITE_ENSURE(context, input_gate_bias != nullptr);
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, bias_type);
  }

  const TfLiteEvalTensor* forget_gate_bias = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kForgetGateBiasTensor);
  TF_LITE_ENSURE(context, forget_gate_bias != nullptr);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
  TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, bias_type);

  const TfLiteEvalTensor* cell_gate_bias = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kCellGateBiasTensor);
  TF_LITE_ENSURE(context, cell_gate_bias != nullptr);
  TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
  TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, bias_type);

  const TfLiteEvalTensor* output_gate_bias = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kOutputGateBiasTensor);
  TF_LITE_ENSURE(context, output_gate_bias != nullptr);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
  TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, bias_type);

  // ---- Optional projection: weights [n_output, n_cell], bias [n_output]. ----
  const TfLiteEvalTensor* projection_weights = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kProjectionWeightsTensor);
  if (projection_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
  }

  const TfLiteEvalTensor* projection_bias = tflite::micro::GetEvalInput(
      context, node, micro::lstm::full::kProjectionBiasTensor);
  if (projection_bias != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
    TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, bias_type);
  }

  // Making sure the projection tensors are consistent:
  // 1) If projection weight is not present, then projection bias should not be
  // present.
  // 2) If projection weight is present, then projection bias is optional.
  const bool projection_tensors_consistent =
      (projection_weights != nullptr) || (projection_bias == nullptr);
  TF_LITE_ENSURE(context, projection_tensors_consistent);

  // ---- Layer-norm coefficients: [n_cell], only when layer norm is used. ----
  if (use_layer_norm) {
    const TfLiteEvalTensor* input_layer_norm_coefficients =
        tflite::micro::GetEvalInput(
            context, node,
            micro::lstm::full::kInputLayerNormCoefficientsTensor);
    if (use_cifg) {
      TF_LITE_ENSURE(context, input_layer_norm_coefficients == nullptr);
    } else {
      TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr);
      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1);
      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0],
                        n_cell);
      TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
                              layer_norm_type);
    }

    const TfLiteEvalTensor* forget_layer_norm_coefficients =
        tflite::micro::GetEvalInput(
            context, node,
            micro::lstm::full::kForgetLayerNormCoefficientsTensor);
    TF_LITE_ENSURE(context, forget_layer_norm_coefficients != nullptr);
    TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
                      n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
                            layer_norm_type);

    const TfLiteEvalTensor* cell_layer_norm_coefficients =
        tflite::micro::GetEvalInput(
            context, node, micro::lstm::full::kCellLayerNormCoefficientsTensor);
    TF_LITE_ENSURE(context, cell_layer_norm_coefficients != nullptr);
    TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
                      n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
                            layer_norm_type);

    const TfLiteEvalTensor* output_layer_norm_coefficients =
        tflite::micro::GetEvalInput(
            context, node,
            micro::lstm::full::kOutputLayerNormCoefficientsTensor);
    TF_LITE_ENSURE(context, output_layer_norm_coefficients != nullptr);
    TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
                      n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
                            layer_norm_type);
  }
  return kTfLiteOk;
}