in tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java [492:707]
/**
 * Builds the public static {@code create} factory method of the generated op class and adds it to
 * {@code builder}.
 *
 * <p>The generated method's parameters are, in order:
 * <ul>
 *   <li>the {@code Scope};
 *   <li>one parameter per op input;
 *   <li>one parameter per required attribute that the resolver does not report as part of an
 *       input;
 *   <li>when an {@code Options} class applies (this class's own, or the selector's for a
 *       state subclass with optional attributes), a trailing varargs {@code Options} array.
 * </ul>
 *
 * <p>Normally the generated body creates an operation builder, adds the inputs and attributes, and
 * returns a new instance wrapping the built operation. For state-selector classes the body instead
 * checks whether any {@code ConcreteFunction} argument is stateful and forwards all arguments to
 * the stateful or stateless class's {@code create}.
 *
 * <p>If any required type attribute has a concrete default data type, a secondary factory is also
 * generated through {@code buildSecondaryFactory}.
 */
private void buildFactoryMethods() {
  MethodSpec.Builder factoryBuilder =
      MethodSpec.methodBuilder("create").addModifiers(Modifier.PUBLIC, Modifier.STATIC);
  // the main creator will inherit any class type params
  TypeName returnType = ClassName.get(fullPackage, className);
  if (!typeParams.isEmpty()) {
    returnType =
        ParameterizedTypeName.get((ClassName) returnType, typeParams.toArray(new TypeName[0]));
  }
  factoryBuilder.returns(returnType);
  AnnotationSpec.Builder endpointAnnotation =
      AnnotationSpec.builder(Names.Endpoint).addMember("describeByClass", "true");
  String methodName = GeneratorUtils.getOpMethodName(className);
  if (methodName != null) {
    endpointAnnotation.addMember("name", "$S", methodName);
  }
  factoryBuilder.addAnnotation(endpointAnnotation.build());
  factoryBuilder.addJavadoc(
      "Factory method to create a class wrapping a new $L operation.\n", op.getName());
  // we're going to build the body as add arguments
  CodeBlock.Builder body = CodeBlock.builder();
  // insertion-ordered so the generated @param tags follow the parameter declaration order
  Map<String, CodeBlock> paramTags = new LinkedHashMap<>();
  factoryBuilder.addParameter(ParameterSpec.builder(Names.Scope, "scope").build());
  paramTags.put("scope", CodeBlock.of("current scope"));
  // deduplicated, first-seen order; becomes the method's type variables
  Set<TypeVariableName> typeVars = new LinkedHashSet<>(typeParams);
  body.addStatement(
      "$T opBuilder = scope.opBuilder($L, $S)", Names.OperationBuilder, OP_NAME_FIELD, className);
  // names of ConcreteFunction parameters, consulted only when generating a state selector
  List<String> functionArgs = new ArrayList<>();
  List<String> iterableFunctionArgs = new ArrayList<>();
  // add the inputs as parameters, and add them to the op builder
  for (ArgDef input : op.getInputArgList()) {
    ApiDef.Arg argDef = argApis.get(input);
    ResolvedType type = resolver.typeOf(input);
    String name = getJavaName(input);
    if (type.javaType.equals(Names.ConcreteFunction)) {
      if (type.iterable) {
        iterableFunctionArgs.add(name);
      } else {
        functionArgs.add(name);
      }
    }
    ParameterSpec.Builder param = ParameterSpec.builder(type.iterableIfIterable().javaType, name);
    String description =
        argDef.getDescription().isEmpty()
            ? String.format("The %s value", name)
            : argDef.getDescription();
    paramTags.put(name, CodeBlock.of("$L", parseDocumentation(description)));
    factoryBuilder.addParameter(param.build());
    typeVars.addAll(type.findGenerics());
    if (type.iterable) {
      body.addStatement("opBuilder.addInputList($T.asOutputs($L))", Names.Operands, name);
    } else {
      body.addStatement("opBuilder.addInput($L.asOutput())", name);
    }
  }
  // add the required attribute params, and build the default type maps for use in the secondary
  // factory
  Map<AttrDef, TypeName> defaultTypes = new HashMap<>();
  Map<String, TypeName> defaultTypeVars = new HashMap<>();
  for (AttrDef attr : requiredAttributes) {
    // attributes already covered by an input's type get no parameter of their own
    if (resolver.partOfInput(attr.getName())) {
      continue;
    }
    ResolvedType type = resolver.typeOf(attr);
    ApiDef.Attr apiAttr = attrApis.get(attr);
    String javaName = getJavaName(attr);
    if (type.javaType.equals(Names.ConcreteFunction)) {
      if (type.iterable) {
        iterableFunctionArgs.add(javaName);
      } else {
        functionArgs.add(javaName);
      }
    }
    // named paramBuilder (not builder) so it does not shadow the class-level `builder` field
    ParameterSpec.Builder paramBuilder =
        ParameterSpec.builder(type.classIfGeneric().listIfIterable().javaType, javaName);
    String description =
        apiAttr.getDescription().isEmpty()
            ? String.format("The value of the %s attribute", javaName)
            : apiAttr.getDescription();
    paramTags.put(javaName, CodeBlock.of("$L", parseDocumentation(description)));
    typeVars.addAll(type.findGenerics());
    factoryBuilder.addParameter(paramBuilder.build());
    // we only add defaults for type variable arguments
    if (attr.hasDefaultValue() && type.shouldWrapInClass()) {
      TypeName defaultType = TypeResolver.forDataType(attr.getDefaultValue().getType());
      // value equality: TypeName overrides equals, so don't rely on instance identity
      if (!(defaultType instanceof WildcardTypeName) && !Names.TType.equals(defaultType)) {
        defaultTypes.put(attr, defaultType);
        defaultTypeVars.put(((TypeVariableName) type.javaType).name, defaultType);
      }
    }
    writeSetAttr(body, attr, type, false);
  }
  // TODO optional function attrs (there currently aren't any)
  // add optional attributes
  if (optionsClass != null || (isStateSubclass && statefulPair.hasOptionalAttrs())) {
    // stateful/stateless subclasses share the Options class nested in their selector
    ClassName optionsClassName;
    if (isStateSubclass) {
      optionsClassName = ClassName.get(fullPackage, statefulPair.selectorClassName, "Options");
    } else {
      optionsClassName = ClassName.get(fullPackage, className, "Options");
    }
    factoryBuilder.addParameter(
        ParameterSpec.builder(ArrayTypeName.of(optionsClassName), "options").build());
    paramTags.put("options", CodeBlock.of("$L", "carries optional attribute values"));
    factoryBuilder.varargs();
    body.beginControlFlow("if (options != null)");
    body.beginControlFlow("for ($T opts : options)", optionsClassName);
    for (AttrDef attr : optionalAttributes) {
      String name = getJavaName(attr);
      body.beginControlFlow("if (opts.$L != null)", name);
      writeSetAttr(body, attr, null, true);
      body.endControlFlow();
    }
    body.endControlFlow();
    body.endControlFlow();
  }
  body.addStatement(
      "return new $L(opBuilder.build())", typeParams.isEmpty() ? className : (className + "<>"));
  if (isStateSelector) {
    // A selector does not build the op itself: discard the body generated above and instead
    // dispatch to the stateful or stateless variant, forwarding every declared parameter.
    body.clear();
    body.addStatement("boolean isStateful = false");
    functionArgs.forEach(
        arg -> {
          body.beginControlFlow("if ($L.isStateful())", arg)
              .addStatement("isStateful = true")
              .endControlFlow();
        });
    iterableFunctionArgs.forEach(
        arg -> {
          body.beginControlFlow("if ($L.stream().anyMatch(x -> x.isStateful()))", arg)
              .addStatement("isStateful = true")
              .endControlFlow();
        });
    StringJoiner argList = new StringJoiner(", ");
    factoryBuilder.parameters.forEach(x -> argList.add(x.name));
    body.beginControlFlow("if (isStateful)")
        .addStatement(
            "return $T.create($L)",
            ClassName.get(fullPackage, statefulPair.statefulClassName),
            argList.toString())
        .nextControlFlow("else")
        .addStatement(
            "return $T.create($L)",
            ClassName.get(fullPackage, statefulPair.statelessClassName),
            argList.toString())
        .endControlFlow();
  }
  factoryBuilder.addCode(body.build());
  paramTags.forEach(
      (param, doc) -> {
        String description = doc.toString();
        if (description.isEmpty() || description.equals("\n")) {
          factoryBuilder.addJavadoc("\n@param $L the $L property", param, param);
        } else {
          factoryBuilder.addJavadoc("\n@param $L $L", param, doc);
        }
      });
  for (TypeVariableName typeVar : typeVars) {
    // pass names as $L arguments rather than concatenating them into the JavaPoet format string,
    // so a '$' in a name can never be misread as a format directive
    factoryBuilder.addJavadoc(
        "\n@param <$L> data type for {@code $L} output and operands", typeVar.name, op.getName());
  }
  factoryBuilder.addTypeVariables(typeVars);
  factoryBuilder.addJavadoc("\n@return a new instance of $L\n", className);
  MethodSpec method = factoryBuilder.build();
  builder.addMethod(method);
  if (!defaultTypes.isEmpty()) {
    buildSecondaryFactory(defaultTypes, defaultTypeVars, method, paramTags);
  }
}