in tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java [710:781]
/**
 * Adds an overload of the main factory method in which every class-wrapped type attribute that
 * has a default type is removed from the signature and filled in with that default.
 *
 * <p>The overload copies the main factory's name, modifiers and annotations. Its return type is
 * the generated class parameterized by {@code typeParams}, with each type variable whose name is
 * a key of {@code defaultTypeVars} replaced by the mapped type. Its body delegates to the main
 * factory:
 * <ul>
 *   <li>each parameter matched to an attribute that should be wrapped in a class and has an
 *       entry in {@code defaultTypes} is replaced by that default's class literal;</li>
 *   <li>every other parameter is kept in the signature and forwarded unchanged.</li>
 * </ul>
 * Generic type variables appearing in the kept parameters become the overload's own type
 * variables. The delegating body is attached only when {@code isStateSelector} is false.
 *
 * @param defaultTypes default Java type for each attribute that has one
 * @param defaultTypeVars replacement types, keyed by type-variable name, used in the overload's
 *     return type
 * @param mainFactory the fully-parameterized factory method this overload delegates to
 * @param paramTags Javadoc text for each parameter, keyed by parameter name
 */
private void buildSecondaryFactory(
    Map<AttrDef, TypeName> defaultTypes,
    Map<String, TypeName> defaultTypeVars,
    MethodSpec mainFactory,
    Map<String, CodeBlock> paramTags) {
  MethodSpec.Builder factoryBuilder =
      MethodSpec.methodBuilder(mainFactory.name)
          .addModifiers(mainFactory.modifiers)
          .returns(
              ParameterizedTypeName.get(
                  ClassName.get(fullPackage, className),
                  typeParams.stream()
                      .map(x -> defaultTypeVars.getOrDefault(x.name, x))
                      .toArray(TypeName[]::new)));
  factoryBuilder.addAnnotations(mainFactory.annotations);
  factoryBuilder.addJavadoc(
      "Factory method to create a class wrapping a new $L operation, with the default output types.\n",
      op.getName());

  // Delegate to the main factory by its actual name ($N on a MethodSpec emits its name), so the
  // overload and the method it calls cannot drift apart.
  CodeBlock.Builder body = CodeBlock.builder();
  body.add("return $N(", mainFactory);

  // Generic type variables still referenced by the kept parameters; insertion-ordered so the
  // generated type-variable list and @param tags are deterministic.
  Set<TypeVariableName> typeVars = new LinkedHashSet<>();

  // Keep every main-factory parameter except class-wrapped type attributes that have a default;
  // kept parameters are simply passed through.
  boolean first = true;
  for (ParameterSpec param : mainFactory.parameters) {
    if (!first) {
      body.add(", ");
    }
    first = false;

    // The op attribute (if any) whose Java name matches this parameter.
    AttrDef attr =
        op.getAttrList().stream()
            .filter(x -> getJavaName(x).equals(param.name))
            .findFirst()
            .orElse(null);

    if (attr != null
        && resolver.typeOf(attr).shouldWrapInClass()
        && defaultTypes.containsKey(attr)) {
      // Defaulted type attribute: omit it from the signature and pass the default's class.
      body.add("$T.class", defaultTypes.get(attr));
    } else {
      factoryBuilder.addParameter(param);
      factoryBuilder.addJavadoc("\n@param $L $L", param.name, paramTags.get(param.name));
      typeVars.addAll(new ResolvedType(param.type).findGenerics());
      body.add("$L", param.name);
    }
  }
  body.add(");");

  // Pass names as format arguments rather than concatenating them into the format string, so a
  // '$' in a name cannot be misread as a JavaPoet directive.
  for (TypeVariableName typeVar : typeVars) {
    factoryBuilder.addJavadoc(
        "\n@param <$L> data type for {@code $L} output and operands",
        typeVar.name,
        op.getName());
  }
  factoryBuilder.addJavadoc(
      "\n@return a new instance of $L, with default output types", className);

  // NOTE(review): presumably the main factory has no body (e.g. is abstract) for state
  // selectors, so none is emitted here either — confirm.
  if (!isStateSelector) {
    factoryBuilder.addCode(body.build());
  }
  factoryBuilder.addTypeVariables(typeVars);
  builder.addMethod(factoryBuilder.build());
}