bool GenerateWrapperOutputs()

in tensorflow_lite_support/codegen/android_java_generator.cc [376:460]


// Emits the Java `Outputs` wrapper class (a static nested class) through
// `code_writer`.
//
// The generated class contains, in order:
//   1. per-output private fields (wrapper, optional label list, postprocessor),
//   2. one public getter per output,
//   3. a constructor that allocates the wrappers and stores the postprocessors,
//   4. `getBuffer()`, mapping each output's position in `model.outputs` to its
//      buffer,
//   5. one private `postprocess<Name>` helper per output.
//
// Params:
//   code_writer: sink for the generated text; its template tokens are reset
//                for every tensor via SetCodeWriterWithTensorInfo.
//   model:       only `model.outputs` is read here.
//   err:         receives a warning for outputs that carry axis labels but are
//                not of content type "tensor".
//
// Returns: always true.
bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model,
                            ErrorReporter* err) {
  code_writer->Append("/** Output wrapper of {@link {{MODEL_CLASS_NAME}}} */");
  // NOTE(review): the block object presumably closes the generated class body
  // when it goes out of scope at the end of this function - confirm in AsBlock.
  auto class_block = AsBlock(code_writer, "public static class Outputs");

  // Fields.
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append("private final {{WRAPPER_TYPE}} {{NAME}};");
    if (tensor.associated_axis_label_index >= 0) {
      code_writer->Append("private final List<String> {{NAME}}Labels;");
    }
    code_writer->Append(
        "private final {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;");
  }

  // Getters.
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->NewLine();
    const bool has_labels = tensor.associated_axis_label_index >= 0;
    if (has_labels && tensor.content_type == "tensor") {
      code_writer->Append(
          R"(public List<Category> get{{NAME_U}}AsCategoryList() {
  return new TensorLabel({{NAME}}Labels, postprocess{{NAME_U}}({{NAME}})).getCategoryList();
})");
      continue;
    }
    if (has_labels) {
      // Labelled but not a plain tensor (e.g. an image).
      err->Warning(
          "Axis label for images is not supported. The labels will "
          "be ignored.");
    }
    // Ignoring the labels must not drop the accessor: the field and its
    // postprocess helper are generated regardless, so without this getter the
    // output would be unreachable from the generated API.
    code_writer->Append(
        R"(public {{WRAPPER_TYPE}} get{{NAME_U}}As{{WRAPPER_TYPE}}() {
  return postprocess{{NAME_U}}({{NAME}});
})");
  }

  // Constructor.
  code_writer->NewLine();
  {
    const auto ctor_block = AsBlock(
        code_writer,
        "Outputs(Metadata metadata, {{POSTPROCESSOR_TYPE_PARAM_LIST}})");
    for (const auto& tensor : model.outputs) {
      SetCodeWriterWithTensorInfo(code_writer, tensor);
      if (tensor.content_type == "image") {
        code_writer->Append(
            R"({{NAME}} = new TensorImage(metadata.get{{NAME_U}}Type());
{{NAME}}.load(TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), metadata.get{{NAME_U}}Type()));)");
      } else {  // FEATURE, UNKNOWN
        code_writer->Append(
            "{{NAME}} = "
            "TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), "
            "metadata.get{{NAME_U}}Type());");
      }
      if (tensor.associated_axis_label_index >= 0) {
        code_writer->Append("{{NAME}}Labels = metadata.get{{NAME_U}}Labels();");
      }
      code_writer->Append(
          "this.{{NAME}}Postprocessor = {{NAME}}Postprocessor;");
    }
  }

  // getBuffer(): keys are the positions of the tensors in `model.outputs`.
  code_writer->NewLine();
  {
    const auto get_buffer_block =
        AsBlock(code_writer, "Map<Integer, Object> getBuffer()");
    code_writer->Append("Map<Integer, Object> outputs = new HashMap<>();");
    // A range-for with an explicit counter avoids comparing a signed index
    // against the container's unsigned size().
    int output_index = 0;
    for (const auto& tensor : model.outputs) {
      SetCodeWriterWithTensorInfo(code_writer, tensor);
      code_writer->SetTokenValue("ID", std::to_string(output_index));
      code_writer->Append("outputs.put({{ID}}, {{NAME}}.getBuffer());");
      ++output_index;
    }
    code_writer->Append("return outputs;");
  }

  // Private postprocess helpers, one per output.
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->NewLine();
    {
      auto processor_block =
          AsBlock(code_writer,
                  "private {{WRAPPER_TYPE}} "
                  "postprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}})");
      code_writer->Append(
          "return {{NAME}}Postprocessor.process({{WRAPPER_NAME}});");
    }
  }
  return true;
}