in tensorflow_lite_support/codegen/android_java_generator.cc [376:460]
// Emits the Java source of the nested `Outputs` wrapper class for `model`.
//
// For each entry of `model.outputs` the generated class contains:
//   * private final fields: the output object ({{WRAPPER_TYPE}}), an optional
//     `List<String>` of axis labels, and the output's postprocessor;
//   * one public getter: `get<Name>AsCategoryList()` for labeled non-image
//     outputs, `get<Name>As<WrapperType>()` for every other output;
//   * a private `postprocess<Name>()` helper that runs the postprocessor.
// It also emits a package-private constructor that allocates every output from
// the `metadata` shape/type accessors, and `getBuffer()`, which maps each
// output's index in `model.outputs` to that output's buffer.
//
// The per-tensor {{...}} tokens are presumably populated by
// SetCodeWriterWithTensorInfo before each Append. MODEL_CLASS_NAME and
// POSTPROCESSOR_TYPE_PARAM_LIST are not set here; they are assumed to be set
// by the caller.
//
// code_writer: destination for the generated Java code.
// model:       model description whose `outputs` drive generation.
// err:         receives a warning for every labeled image output.
// Returns true; this function currently has no failure path.
bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model,
                            ErrorReporter* err) {
  code_writer->Append("/** Output wrapper of {@link {{MODEL_CLASS_NAME}}} */");
  const auto class_block = AsBlock(code_writer, "public static class Outputs");

  // Fields.
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append("private final {{WRAPPER_TYPE}} {{NAME}};");
    if (tensor.associated_axis_label_index >= 0) {
      code_writer->Append("private final List<String> {{NAME}}Labels;");
    }
    code_writer->Append(
        "private final {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;");
  }

  // Getters. Every output gets exactly one public getter. The image test
  // matches the constructor below, which wraps only "image" outputs in a
  // TensorImage. The category-list getter is therefore emitted only for
  // outputs backed by a TensorBuffer.
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->NewLine();
    const bool has_labels = tensor.associated_axis_label_index >= 0;
    const bool is_image = tensor.content_type == "image";
    if (has_labels && !is_image) {
      code_writer->Append(
          R"(public List<Category> get{{NAME_U}}AsCategoryList() {
return new TensorLabel({{NAME}}Labels, postprocess{{NAME_U}}({{NAME}})).getCategoryList();
})");
      continue;
    }
    if (has_labels) {
      // Labeled image output: only the labels are dropped. The plain getter
      // below is still emitted so the output remains accessible.
      err->Warning(
          "Axis label for images is not supported. The labels will "
          "be ignored.");
    }
    code_writer->Append(
        R"(public {{WRAPPER_TYPE}} get{{NAME_U}}As{{WRAPPER_TYPE}}() {
return postprocess{{NAME_U}}({{NAME}});
})");
  }

  // Constructor: allocate each output from metadata and store its
  // postprocessor.
  code_writer->NewLine();
  {
    const auto ctor_block = AsBlock(
        code_writer,
        "Outputs(Metadata metadata, {{POSTPROCESSOR_TYPE_PARAM_LIST}})");
    for (const auto& tensor : model.outputs) {
      SetCodeWriterWithTensorInfo(code_writer, tensor);
      if (tensor.content_type == "image") {
        code_writer->Append(
            R"({{NAME}} = new TensorImage(metadata.get{{NAME_U}}Type());
{{NAME}}.load(TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), metadata.get{{NAME_U}}Type()));)");
      } else {  // FEATURE, UNKNOWN
        code_writer->Append(
            "{{NAME}} = "
            "TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), "
            "metadata.get{{NAME_U}}Type());");
      }
      if (tensor.associated_axis_label_index >= 0) {
        // The labels field is declared final above, so it must be assigned
        // here even when the getter ignores it (image outputs).
        code_writer->Append("{{NAME}}Labels = metadata.get{{NAME_U}}Labels();");
      }
      code_writer->Append(
          "this.{{NAME}}Postprocessor = {{NAME}}Postprocessor;");
    }
  }

  // getBuffer(): maps each output's position in model.outputs to its buffer.
  code_writer->NewLine();
  {
    const auto get_buffer_block =
        AsBlock(code_writer, "Map<Integer, Object> getBuffer()");
    code_writer->Append("Map<Integer, Object> outputs = new HashMap<>();");
    for (std::size_t i = 0; i < model.outputs.size(); ++i) {
      SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]);
      code_writer->SetTokenValue("ID", std::to_string(i));
      code_writer->Append("outputs.put({{ID}}, {{NAME}}.getBuffer());");
    }
    code_writer->Append("return outputs;");
  }

  // Private postprocess helpers used by the getters above.
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->NewLine();
    {
      const auto processor_block =
          AsBlock(code_writer,
                  "private {{WRAPPER_TYPE}} "
                  "postprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}})");
      code_writer->Append(
          "return {{NAME}}Postprocessor.process({{WRAPPER_NAME}});");
    }
  }
  return true;
}