bool GenerateWrapperMetadata()

in tensorflow_lite_support/codegen/android_java_generator.cc [462:606]


// Emits the generated wrapper's nested Java class `Metadata`.
//
// For every input and output tensor in `model`, the generated class contains:
//   - private final fields for the tensor's shape (int[]), DataType and
//     QuantizationParams;
//   - float[] mean/stddev fields when the tensor has a normalization process
//     unit (`normalization_unit >= 0`);
//   - (outputs only) a List<String> of labels when the tensor has an associated
//     axis-label or value-label file index (`>= 0`).
// It also emits a constructor `Metadata(ByteBuffer, Model)` that fills those
// fields from the TFLite `Model` tensors and a `MetadataExtractor`. After the
// constructor it emits one public getter per field. Array getters return
// copies; the labels getter returns the stored list.
//
// Params:
//   code_writer - destination writer. Its template tokens (e.g. {{ID}}, and
//                 presumably {{NAME}}/{{NAME_U}}/... via
//                 SetCodeWriterWithTensorInfo) are reassigned per tensor.
//   model       - tensor information to generate accessors for.
//   err         - currently unused.
// Returns: always true.
bool GenerateWrapperMetadata(CodeWriter* code_writer, const ModelInfo& model,
                             ErrorReporter* err) {
  code_writer->Append(
      "/** Metadata accessors of {@link {{MODEL_CLASS_NAME}}} */");
  // NOTE(review): AsBlock is assumed to close the Java block when the returned
  // object goes out of scope, i.e. at the end of this function.
  const auto class_block = AsBlock(code_writer, "public static class Metadata");

  // --- Field declarations -------------------------------------------------
  for (const auto& tensor : model.inputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append(R"(private final int[] {{NAME}}Shape;
private final DataType {{NAME}}DataType;
private final QuantizationParams {{NAME}}QuantizationParams;)");
    if (tensor.normalization_unit >= 0) {
      code_writer->Append(R"(private final float[] {{NAME}}Mean;
private final float[] {{NAME}}Stddev;)");
    }
  }
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append(R"(private final int[] {{NAME}}Shape;
private final DataType {{NAME}}DataType;
private final QuantizationParams {{NAME}}QuantizationParams;)");
    if (tensor.normalization_unit >= 0) {
      code_writer->Append(R"(private final float[] {{NAME}}Mean;
private final float[] {{NAME}}Stddev;)");
    }
    if (tensor.associated_axis_label_index >= 0 ||
        tensor.associated_value_label_index >= 0) {
      code_writer->Append("private final List<String> {{NAME}}Labels;");
    }
  }
  code_writer->NewLine();

  // --- Constructor --------------------------------------------------------
  {
    const auto ctor_block = AsBlock(
        code_writer,
        "public Metadata(ByteBuffer buffer, Model model) throws IOException");
    code_writer->Append(
        "MetadataExtractor extractor = new MetadataExtractor(buffer);");
    for (size_t i = 0; i < model.inputs.size(); ++i) {
      const auto& tensor = model.inputs[i];
      SetCodeWriterWithTensorInfo(code_writer, tensor);
      code_writer->SetTokenValue("ID", std::to_string(i));
      code_writer->Append(
          R"(Tensor {{NAME}}Tensor = model.getInputTensor({{ID}});
{{NAME}}Shape = {{NAME}}Tensor.shape();
{{NAME}}DataType = {{NAME}}Tensor.dataType();
{{NAME}}QuantizationParams = {{NAME}}Tensor.quantizationParams();)");
      if (tensor.normalization_unit >= 0) {
        code_writer->Append(
            R"(NormalizationOptions {{NAME}}NormalizationOptions =
    (NormalizationOptions) extractor.getInputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions());
FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer();
{{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()];
{{NAME}}MeanBuffer.get({{NAME}}Mean);
FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer();
{{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()];
{{NAME}}StddevBuffer.get({{NAME}}Stddev);)");
      }
    }
    for (size_t i = 0; i < model.outputs.size(); ++i) {
      const auto& tensor = model.outputs[i];
      SetCodeWriterWithTensorInfo(code_writer, tensor);
      code_writer->SetTokenValue("ID", std::to_string(i));
      code_writer->Append(
          R"(Tensor {{NAME}}Tensor = model.getOutputTensor({{ID}});
{{NAME}}Shape = {{NAME}}Tensor.shape();
{{NAME}}DataType = {{NAME}}Tensor.dataType();
{{NAME}}QuantizationParams = {{NAME}}Tensor.quantizationParams();)");
      if (tensor.normalization_unit >= 0) {
        // Output tensors must read their process units from the *output*
        // tensor metadata. Reading from getInputTensorMetadata({{ID}}) here
        // would pick up the i-th input tensor's metadata instead.
        code_writer->Append(
            R"(NormalizationOptions {{NAME}}NormalizationOptions =
    (NormalizationOptions) extractor.getOutputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions());
FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer();
{{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()];
{{NAME}}MeanBuffer.get({{NAME}}Mean);
FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer();
{{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()];
{{NAME}}StddevBuffer.get({{NAME}}Stddev);)");
      }
      // Axis labels take precedence over value labels when both are present.
      if (tensor.associated_axis_label_index >= 0) {
        code_writer->Append(R"(String {{NAME}}LabelsFileName =
    extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_AXIS_LABEL_INDEX}}).name();
{{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)");
      } else if (tensor.associated_value_label_index >= 0) {
        code_writer->Append(R"(String {{NAME}}LabelsFileName =
    extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_VALUE_LABEL_INDEX}}).name();
{{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)");
      }
    }
  }

  // --- Getters ------------------------------------------------------------
  for (const auto& tensor : model.inputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append(R"(
public int[] get{{NAME_U}}Shape() {
  return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length);
}

public DataType get{{NAME_U}}Type() {
  return {{NAME}}DataType;
}

public QuantizationParams get{{NAME_U}}QuantizationParams() {
  return {{NAME}}QuantizationParams;
})");
    if (tensor.normalization_unit >= 0) {
      code_writer->Append(R"(
public float[] get{{NAME_U}}Mean() {
  return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length);
}

public float[] get{{NAME_U}}Stddev() {
  return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length);
})");
    }
  }
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append(R"(
public int[] get{{NAME_U}}Shape() {
  return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length);
}

public DataType get{{NAME_U}}Type() {
  return {{NAME}}DataType;
}

public QuantizationParams get{{NAME_U}}QuantizationParams() {
  return {{NAME}}QuantizationParams;
})");
    if (tensor.normalization_unit >= 0) {
      code_writer->Append(R"(
public float[] get{{NAME_U}}Mean() {
  return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length);
}

public float[] get{{NAME_U}}Stddev() {
  return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length);
})");
    }
    if (tensor.associated_axis_label_index >= 0 ||
        tensor.associated_value_label_index >= 0) {
      code_writer->Append(R"(
public List<String> get{{NAME_U}}Labels() {
  return {{NAME}}Labels;
})");
    }
  }
  return true;
}