in tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java [942:1055]
/**
 * Generates the nested {@code Inputs} class for this op and adds it to the op class being built.
 *
 * <p>The generated class extends {@code RawOpInputs} parameterized with the op class. It exposes
 * one public final field per op input, plus one per op attribute whose resolved type has an
 * attribute type. Its public constructor takes a {@code GraphOperation op}, calls {@code super}
 * with the list of those attribute names, and then initializes each field. Inputs are read from
 * {@code op} in declaration order, and attributes are read via {@code op.attributes()}.
 *
 * <p>Type variables of the op class that no input field uses are replaced by {@code ?} in the
 * superclass, because the generated {@code Inputs} class only declares the type variables its
 * input fields reference.
 *
 * @return the generic type variables referenced by the op's inputs; these are also declared as
 *     the type variables of the generated {@code Inputs} class
 */
private Set<TypeVariableName> buildInputsClass() {
  TypeSpec.Builder inputsBuilder =
      TypeSpec.classBuilder(INPUTS_CLASS_NAME).addModifiers(Modifier.PUBLIC, Modifier.STATIC);
  MethodSpec.Builder ctor = MethodSpec.constructorBuilder().addModifiers(Modifier.PUBLIC);
  ctor.addParameter(Names.GraphOperation, "op");
  // Quoted names of the emitted attributes, rendered as the arguments of Arrays.asList(...)
  // in the generated super(...) call.
  StringJoiner attrNames = new StringJoiner(", ");
  // Generic type variables referenced by input fields; insertion-ordered for stable output.
  Set<TypeVariableName> typeVars = new LinkedHashSet<>();
  // Field-assignment statements, emitted after the super(...) call in the constructor.
  CodeBlock.Builder fieldInits = CodeBlock.builder();
  // Generated running index into the operation's inputs; list inputs advance it by their length.
  fieldInits.addStatement("int inputIndex = 0");

  // Add one field per op input, initialized from the operation in declaration order.
  for (ArgDef input : op.getInputArgList()) {
    ResolvedType type = resolver.typeOf(input);
    String name = getJavaName(input);
    ApiDef.Arg argDef = argApis.get(input);
    typeVars.addAll(type.findGenerics());
    TypeName javaType = type.iterableIfIterable().javaType;
    String description =
        argDef.getDescription().isEmpty()
            ? String.format("The %s input", name)
            : argDef.getDescription();
    inputsBuilder.addField(
        FieldSpec.builder(javaType, name, Modifier.PUBLIC, Modifier.FINAL)
            .addJavadoc("$L", parseDocumentation(description))
            .build());
    if (type.iterable) {
      // List inputs: look up the list's length by its op-def name, wrap that many inputs
      // starting at the current index, then advance the index past them.
      String inputListLength = name + "Length";
      fieldInits.addStatement(
          "int $L = op.inputListLength($S)", inputListLength, input.getName());
      fieldInits.addStatement(
          "$L = $T.asList(($T) op.inputList(inputIndex, $L))",
          name,
          Names.Arrays,
          ArrayTypeName.of(type.javaType),
          inputListLength);
      fieldInits.addStatement("inputIndex += $L", inputListLength);
    } else {
      fieldInits.addStatement("$L = ($T) op.input(inputIndex++)", name, javaType);
    }
  }

  // Add one field per attribute whose resolved type has an attribute type; others are skipped.
  for (AttrDef attr : op.getAttrList()) {
    ResolvedType type = resolver.typeOf(attr);
    String name = getJavaName(attr);
    if (type.attributeType != null) {
      ApiDef.Attr apiAttr = attrApis.get(attr);
      String description =
          apiAttr.getDescription().isEmpty()
              ? String.format("The %s attribute", name)
              : apiAttr.getDescription();
      TypeName javaType = type.jniType;
      if (type.iterable) {
        javaType = ArrayTypeName.of(javaType);
      }
      attrNames.add(CodeBlock.of("$S", attr.getName()).toString());
      inputsBuilder.addField(
          FieldSpec.builder(javaType, name, Modifier.PUBLIC, Modifier.FINAL)
              // Use the same documentation parsing as input fields so that both kinds of
              // fields in the generated class get consistently formatted Javadoc.
              .addJavadoc("$L", parseDocumentation(description))
              .build());
      fieldInits.addStatement(
          "$L = op.attributes().getAttr$L($S)",
          name,
          type.attributeType.getterName(type.iterable),
          attr.getName());
    }
  }

  // Parameterize the superclass with the op class. Any op-class type variable that is not
  // declared on the Inputs class (i.e. not used by an input field) becomes a `?` wildcard.
  List<TypeName> sharedTypeVars = new ArrayList<>();
  for (TypeVariableName onClass : this.builder.typeVariables) {
    if (typeVars.contains(onClass)) {
      sharedTypeVars.add(onClass);
    } else {
      sharedTypeVars.add(WildcardTypeName.subtypeOf(TypeName.OBJECT));
    }
  }
  TypeName outputClass = className();
  if (!this.builder.typeVariables.isEmpty()) {
    outputClass =
        ParameterizedTypeName.get(
            (ClassName) outputClass, sharedTypeVars.toArray(new TypeName[0]));
  }
  inputsBuilder.superclass(ParameterizedTypeName.get(Names.RawOpInputs, outputClass));

  // Constructor body: super(new Op<>(op), op, Arrays.asList(<attr names>)), then field inits.
  CodeBlock.Builder body = CodeBlock.builder();
  body.addStatement(
      "super(new $L(op), op, $T.asList($L))",
      this.builder.typeVariables.isEmpty() ? className : className + "<>",
      Names.Arrays,
      attrNames.toString());
  body.add(fieldInits.build());
  ctor.addCode(body.build());
  inputsBuilder.addMethod(ctor.build());
  inputsBuilder.addTypeVariables(typeVars);
  addInputsMetadataAnnotation(inputsBuilder);
  this.builder.addType(inputsBuilder.build());
  return typeVars;
}