in tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java [237:365]
/**
 * Populates {@code builder} with the definition of the generated op class.
 *
 * <p>The builder calls below append to the class being generated, so their order determines the
 * order of javadoc fragments, superinterfaces, type variables and members in the output. Steps:
 *
 * <ol>
 *   <li>modifiers: always {@code public}; unless this is a state selector, also {@code final} and
 *       the inputs-metadata annotation. A state subclass additionally implements its selector
 *       interface;
 *   <li>class javadoc from the op summary (falling back to "The ... operation") and description,
 *       plus a "Selects between ..." paragraph for state selectors;
 *   <li>for ops with exactly one output: an {@code Operand} or {@code Iterable<Operand>}
 *       superinterface (skipped for state subclasses) and the matching {@link RenderMode};
 *   <li>one type variable, with {@code @param} javadoc, per distinct generic found in the outputs;
 *   <li>{@code @Deprecated} plus {@code @deprecated} javadoc when this endpoint is deprecated;
 *   <li>the {@code @Operator} annotation (with a non-"core" group) unless the op is hidden;
 *   <li>options class, factory methods, getters/setters and, when a render mode was chosen, the
 *       interface implementation;
 *   <li>for non-selector classes: the op-name constant, one private field per output, the
 *       constructor, the inputs class and the {@code RawOp} superclass.
 * </ol>
 */
void buildClass() {
  builder.addModifiers(Modifier.PUBLIC);
  if (!isStateSelector) {
    builder.addModifiers(Modifier.FINAL);
    addInputsMetadataAnnotation();
  }
  if (isStateSubclass) {
    // State subclasses implement the selector interface generated for their stateful pair.
    builder.addSuperinterface(ClassName.get(fullPackage, statefulPair.selectorClassName));
  }

  // add class javadocs
  String summary = parseDocumentation(apiDef.getSummary());
  if (!summary.isEmpty()) {
    builder.addJavadoc("$L", summary + "\n");
  } else {
    builder.addJavadoc("The $L operation\n", apiDef.getGraphOpName());
  }
  String desc = parseDocumentation(apiDef.getDescription());
  if (!desc.isEmpty()) {
    builder.addJavadoc("$L", desc + "\n");
  }
  if (isStateSelector) {
    builder.addJavadoc("\n<p>");
    // Class names go through $L arguments instead of being concatenated into the format string,
    // so a '$' in a name can never be misinterpreted as a JavaPoet placeholder.
    builder.addJavadoc(
        "Selects between {@link $L} and {@link $L} based on the statefulness of the function"
            + " arguments.",
        statefulPair.statefulClassName,
        statefulPair.statelessClassName);
  }

  // add superinterface and set mode (only single-output ops act as an Operand / operand list)
  if (op.getOutputArgCount() == 1) {
    ArgDef output = op.getOutputArg(0);
    ResolvedType rType = resolver.typeOf(output);
    TypeName type = rType.unwrapArg();
    boolean iterable = rType.iterable;
    // A wildcard cannot be used as a superinterface type argument, so fall back to TType.
    TypeName operandTypeParam = type instanceof WildcardTypeName ? Names.TType : type;
    TypeName operandType = ParameterizedTypeName.get(Names.Operand, operandTypeParam);
    if (iterable) {
      mode = RenderMode.LIST_OPERAND;
      // NOTE(review): state subclasses skip this, presumably because the selector interface
      // already provides it — confirm.
      if (!isStateSubclass) {
        builder.addSuperinterface(
            ParameterizedTypeName.get(ClassName.get(Iterable.class), operandType));
      }
    } else {
      mode = RenderMode.OPERAND;
      if (!isStateSubclass) {
        builder.addSuperinterface(operandType);
      }
    }
  }

  // add and store type variables, once per distinct generic name across all outputs
  Set<String> seenGenerics = new HashSet<>();
  for (ArgDef output : op.getOutputArgList()) {
    ResolvedType type = resolver.typeOf(output);
    for (TypeVariableName typeVar : type.findGenerics()) {
      if (seenGenerics.add(typeVar.name)) {
        typeParams.add(typeVar);
        builder.addTypeVariable(typeVar);
        builder.addJavadoc(
            "\n@param <$L> data type for {@code $L} output\n", typeVar.name, output.getName());
      }
    }
  }

  // add deprecated if necessary
  if (endpoint.getDeprecated()) {
    builder.addAnnotation(Deprecated.class);
    Endpoint first = apiDef.getEndpoint(0);
    String explanation;
    if (!first.getDeprecated()) {
      // Point users at the first (non-deprecated) endpoint as the replacement.
      explanation = "use {@link " + basePackage + "." + first.getName() + "} instead";
    } else {
      explanation = op.getDeprecation().getExplanation();
    }
    builder.addJavadoc("\n@deprecated $L", explanation);
  }

  // add the Operator annotation; the group member is only emitted for non-"core" groups
  if (apiDef.getVisibility() != Visibility.HIDDEN) {
    AnnotationSpec.Builder annotation = AnnotationSpec.builder(Names.Operator);
    if (!group.equals("core")) {
      annotation.addMember("group", "$S", group);
    }
    builder.addAnnotation(annotation.build());
  }

  if (!optionalAttributes.isEmpty() && !isStateSubclass) {
    buildOptionsClass();
  }
  buildFactoryMethods();
  buildGettersAndSetters();
  if (mode != RenderMode.DEFAULT) {
    buildInterfaceImpl();
  }

  if (!isStateSelector) {
    // add op name field
    builder.addField(
        FieldSpec.builder(
                TypeResolver.STRING,
                OP_NAME_FIELD,
                Modifier.PUBLIC,
                Modifier.STATIC,
                Modifier.FINAL)
            .addJavadoc("$L", "The name of this op, as known by TensorFlow core engine")
            .initializer("$S", op.getName())
            .build());
    // add output fields (no outputs simply means no iterations)
    for (ArgDef output : op.getOutputArgList()) {
      builder.addField(
          resolver.typeOf(output).listIfIterable().javaType,
          getJavaName(output),
          Modifier.PRIVATE);
    }
    buildConstructor();
    buildInputsClass();
    builder.superclass(Names.RawOp);
  }
}