TfLiteStatus NNAPIDelegateKernel::AddOpsAndTensors()

in tensorflow/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc [2955:3187]


// Translates every TfLite node listed in `nodes_` into an NNAPI operation.
//
// For each node this function:
//   1. adds the node's inputs as NNAPI operands, skipping (or converting to
//      scalars) the inputs that the op-specific mapping function returned by
//      Map() is documented below to add itself;
//   2. calls the mapping function returned by Map() to obtain the NNAPI
//      operation type;
//   3. adds the node's outputs as NNAPI operands;
//   4. adds dequantize operators where needed and finalizes the operation.
//
// `context` is the TfLite context used to look up nodes, registrations and
// tensors, and to report errors.
//
// Returns kTfLiteOk on success, or a non-OK status as soon as any lookup or
// builder call fails.
TfLiteStatus NNAPIDelegateKernel::AddOpsAndTensors(TfLiteContext* context) {
  DequantizeMapping dequantize_mapping;
  // The operand builder allows creating a single op. It is created outside
  // the for loop to avoid reallocating the vectors.
  NNAPIOpBuilder builder(nnapi_, context, &operand_mapping_,
                         &dequantize_mapping, &allocation_memory_mapping_,
                         nn_model_.get());
  // Add Tensors.
  for (auto node_index : nodes_) {
    // Obtain the op and registration.
    TfLiteNode* node;
    TfLiteRegistration* reg;
    TF_LITE_ENSURE_STATUS(
        context->GetNodeAndRegistration(context, node_index, &node, &reg));

    const bool hybrid_op = IsHybridOperator(context, reg->builtin_code, node);
    const bool scalar_as_tensor = IsScalarInputSupported(reg->builtin_code);
    const bool need_int8_conversion =
        NeedInt8Conversion(context, reg->builtin_code, node);
    int input_tensor_flags = 0;
    if (scalar_as_tensor) {
      input_tensor_flags |= NN_TENSOR_FLAG_SCALAR_AS_TENSOR;
    }

    // Map inputs to NN API tensor indices.
    for (int input_pos = 0; input_pos < node->inputs->size; ++input_pos) {
      const auto input_index = node->inputs->data[input_pos];
      // When int8 conversion is needed, the first input is always converted;
      // for the ops listed here every input is converted.
      if (need_int8_conversion &&
          (input_pos == 0 ||
           reg->builtin_code == kTfLiteBuiltinFullyConnected ||
           reg->builtin_code == kTfLiteBuiltinAdd ||
           reg->builtin_code == kTfLiteBuiltinMul ||
           reg->builtin_code == kTfLiteBuiltinSub ||
           reg->builtin_code == kTfLiteBuiltinConcatenation ||
           reg->builtin_code == kTfLiteBuiltinMaximum ||
           reg->builtin_code == kTfLiteBuiltinMinimum ||
           reg->builtin_code == kTfLiteBuiltinLess ||
           reg->builtin_code == kTfLiteBuiltinLessEqual ||
           reg->builtin_code == kTfLiteBuiltinGreater ||
           reg->builtin_code == kTfLiteBuiltinGreaterEqual ||
           reg->builtin_code == kTfLiteBuiltinEqual ||
           reg->builtin_code == kTfLiteBuiltinNotEqual ||
           reg->builtin_code == kTfLiteBuiltinSelect)) {
        // Only selected inputs require int8 conversion.
        TF_LITE_ENSURE_STATUS(builder.AddTensorInput(
            input_index, hybrid_op,
            input_tensor_flags | NN_TENSOR_FLAG_INT8_CONVERSION));
        continue;
      }
      if (reg->builtin_code == kTfLiteBuiltinLstm && isLstmFullKernel(node) &&
          input_pos >= 20) {
        // Skip layer normalization weights. They are added in the Map
        // function (after all the other inputs added there) since layer
        // normalization weights are the last four inputs of the LSTM op in
        // NNAPI.
        continue;
      }
      if (reg->builtin_code == kTfLiteBuiltinLstm && isLstmBasicKernel(node)) {
        // Configuring all inputs in the Map function
        continue;
      }
      if (reg->builtin_code == kTfLiteBuiltinUnidirectionalSequenceLstm) {
        if (input_pos >= 20) {
          // Skip layer normalization weights. They are added in the Map
          // function (after all the other inputs added there) since layer
          // normalization weights are the last four inputs of the
          // unidirectional sequence LSTM op in NNAPI.
          continue;
        }
        if (input_index == kOptionalTensor) {
          // An omitted optional input is represented by an empty float32
          // vector operand.
          TF_LITE_ENSURE_STATUS(builder.AddVectorFloat32Operand(nullptr, 0));
          continue;
        }
      }
      if ((reg->builtin_code == kTfLiteBuiltinSplit) &&
          (input_index == node->inputs->data[0])) {
        // Skip the axis input tensor; it will be added as a scalar operand
        // by the Map() mapping.
        continue;
      }
      if (reg->builtin_code == kTfLiteBuiltinTransposeConv) {
        // Everything is added during Map since input tensors
        // have different order.
        continue;
      }

      // Pad and Padv2 have an optional parameter for a pad value which has
      // to be converted to a scalar type in NN API.
      if ((reg->builtin_code == kTfLiteBuiltinPadv2 ||
           reg->builtin_code == kTfLiteBuiltinPad) &&
          node->inputs->size == 3 && input_pos == 2) {
        const int constant_value_id = node->inputs->data[2];
        if (constant_value_id == kOptionalTensor) {
          continue;
        }
        // Bound by reference: the tensor is only inspected, so there is no
        // need to copy the whole TfLiteTensor struct.
        const TfLiteTensor& constant_value =
            context->tensors[constant_value_id];

        // A read-only (kTfLiteMmapRo) pad value is read now and added as a
        // scalar constant; any other allocation type is added through
        // AddSingleValueTensorAsScalarOperand. Every builder status is
        // propagated so a failed operand insertion aborts model construction.
        switch (constant_value.type) {
          case kTfLiteFloat32:
            if (constant_value.allocation_type == kTfLiteMmapRo) {
              TF_LITE_ENSURE_STATUS(
                  builder.AddScalarFloat32Operand(*constant_value.data.f));
            } else {
              TF_LITE_ENSURE_STATUS(
                  builder.AddSingleValueTensorAsScalarOperand(
                      constant_value_id, ANEURALNETWORKS_FLOAT32));
            }
            break;
          case kTfLiteUInt8:
            if (constant_value.allocation_type == kTfLiteMmapRo) {
              TF_LITE_ENSURE_STATUS(builder.AddScalarInt32Operand(
                  static_cast<int32_t>(*constant_value.data.uint8)));
            } else {
              TF_LITE_ENSURE_STATUS(
                  builder.AddSingleValueTensorAsScalarOperand(
                      constant_value_id, ANEURALNETWORKS_INT32));
            }
            break;
          case kTfLiteInt8:
            if (constant_value.allocation_type == kTfLiteMmapRo) {
              // The int8 value is offset by 128 (presumably to match the
              // uint8 representation used for int8-converted tensors).
              TF_LITE_ENSURE_STATUS(builder.AddScalarInt32Operand(
                  static_cast<int32_t>(*constant_value.data.int8) + 128));
            } else {
              // NOTE(review): no +128 offset is applied on this path, unlike
              // the constant case above - confirm this is intended.
              TF_LITE_ENSURE_STATUS(
                  builder.AddSingleValueTensorAsScalarOperand(
                      constant_value_id, ANEURALNETWORKS_INT32));
            }
            break;
          default:
            context->ReportError(context,
                                 "Unsupported type of pad value for pad_v2\n");
            return kTfLiteError;
        }
        continue;
      }

      if (input_index == kOptionalTensor &&
          (reg->builtin_code == kTfLiteBuiltinLstm ||
           reg->builtin_code == kTfLiteBuiltinSvdf ||
           reg->builtin_code == kTfLiteBuiltinBidirectionalSequenceLstm)) {
        // properly handle the optional tensor for LSTM and SVDF.
        // currently only support float32.
        TF_LITE_ENSURE_STATUS(builder.AddVectorFloat32Operand(nullptr, 0));
      } else if (reg->builtin_code == kTfLiteBuiltinResizeBilinear ||
                 reg->builtin_code == kTfLiteBuiltinResizeNearestNeighbor) {
        if (input_pos == 0) {
          // Only the first input tensor is added. The second one,
          // specifying the output height and width, is not added and
          // instead the height and width will be added individually as
          // scalars by the mapping function returned by Map().
          TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index, hybrid_op));
        }
      } else if (reg->builtin_code == kTfLiteBuiltinTopkV2 && input_pos > 0) {
        // The K parameter tensor is not handled here but by the functor
        // returned by Map, the input tensor is instead added in
        // the else clause below
        continue;
      } else if (reg->builtin_code == kTfLiteBuiltinGather) {
        // Everything is added during Map since input tensors
        // have different order.
        continue;
      } else if (reg->builtin_code == kTfLiteBuiltinExpandDims &&
                 input_pos == 1) {
        // The axis param is added during Map
        continue;
      } else if (reg->builtin_code == kTfLiteBuiltinBatchToSpaceNd &&
                 input_pos == 2) {
        // NNAPI does not support crops.
        // The Map function will check if all crops are zero.
        continue;
      } else if (reg->builtin_code == kTfLiteBuiltinArgMin ||
                 reg->builtin_code == kTfLiteBuiltinArgMax) {
        // The first input tensor is added as is. The second one, specifying
        // the axis, needs to be converted to a scalar since TFLite uses a
        // tensor but NNAPI uses a scalar as the axis.
        if (input_pos == 0) {
          TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index, hybrid_op));
        } else {
          const int axis_id = node->inputs->data[1];
          const TfLiteTensor& axis_tensor = context->tensors[axis_id];
          switch (axis_tensor.type) {
            case kTfLiteInt32:
              if (axis_tensor.allocation_type == kTfLiteMmapRo) {
                TF_LITE_ENSURE_STATUS(builder.AddScalarInt32Operand(
                    static_cast<int32_t>(*axis_tensor.data.i32)));
              } else {
                TF_LITE_ENSURE_STATUS(
                    builder.AddSingleValueTensorAsScalarOperand(
                        axis_id, ANEURALNETWORKS_INT32));
              }
              break;
            case kTfLiteInt64:
              // Map() function already makes sure int64 input is constant.
              TF_LITE_ENSURE_STATUS(builder.AddScalarInt32Operand(
                  static_cast<int32_t>(*axis_tensor.data.i64)));
              break;
            default:
              // Report the failure instead of returning an error silently.
              context->ReportError(
                  context,
                  "Unsupported type of axis tensor for arg_min/arg_max\n");
              return kTfLiteError;
          }
        }
      } else {
        TF_LITE_ENSURE_STATUS(
            builder.AddTensorInput(input_index, hybrid_op, input_tensor_flags));
      }
    }
    // Get op type and operands
    int nn_op_type = Map(context, reg->builtin_code, reg->version,
                         nnapi_->android_sdk_version, node,
                         /*is_accelerator_specified=*/nnapi_device_ != nullptr)(
        {context, &builder, node, &model_state_outputs_,
         &model_state_tfl_inputs_, &feedback_loops_});
    // Map outputs to NN API tensor indices.
    int output_tensor_flags = 0;
    if (need_int8_conversion) {
      output_tensor_flags |= NN_TENSOR_FLAG_INT8_CONVERSION;
    }
    for (int output_pos = 0; output_pos < node->outputs->size; ++output_pos) {
      const auto output_index = node->outputs->data[output_pos];

      // Outputs for the basic LSTM cell are set in the Map function, so they
      // are not added here.
      if (reg->builtin_code == kTfLiteBuiltinLstm && isLstmBasicKernel(node)) {
        continue;
      }

      TF_LITE_ENSURE_STATUS(
          builder.AddTensorOutput(output_index, output_tensor_flags));
    }

    // Dequantize operators may have to be added in case inputs are to be
    // floating-point.
    AddDequantizeOperatorsWhereNeeded(context, reg->builtin_code, node,
                                      &builder);

    // Propagate a failure to add the operation to the NNAPI model rather
    // than reporting success for a partially built model.
    TF_LITE_ENSURE_STATUS(builder.FinalizeAddOperation(nn_op_type));
  }
  return kTfLiteOk;
}