private void buildFactoryMethods()

in tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java [492:707]


  /**
   * Builds the public static {@code create} factory method for the generated op class and adds it
   * to {@code builder}.
   *
   * <p>The generated method takes a {@code Scope}, then one parameter per op input, then one per
   * required attribute that the resolver does not treat as part of an input. If the op (or, for a
   * state subclass, its stateful pair) has optional attributes, a trailing varargs
   * {@code Options} array follows. The generated body creates an operation builder, adds the
   * inputs and attributes, and returns a new instance of the op class.
   *
   * <p>For state-selector ops the body is instead replaced by a dispatch. It checks whether any
   * function argument reports {@code isStateful()`}. Depending on the result, it forwards all
   * parameters to the {@code create} method of either the stateful or the stateless class.
   *
   * <p>Some required attributes may have a default data type that is neither a wildcard nor the
   * generic {@code TType}. Those defaults are collected, and a secondary factory is generated
   * from them via {@code buildSecondaryFactory}.
   */
  private void buildFactoryMethods() {
    MethodSpec.Builder factoryBuilder =
        MethodSpec.methodBuilder("create").addModifiers(Modifier.PUBLIC, Modifier.STATIC);

    // the main creator will inherit any class type params
    ClassName opClassName = ClassName.get(fullPackage, className);
    TypeName returnType =
        typeParams.isEmpty()
            ? opClassName
            : ParameterizedTypeName.get(opClassName, typeParams.toArray(new TypeName[0]));
    factoryBuilder.returns(returnType);

    AnnotationSpec.Builder endpointAnnotation =
        AnnotationSpec.builder(Names.Endpoint).addMember("describeByClass", "true");

    String methodName = GeneratorUtils.getOpMethodName(className);

    if (methodName != null) {
      endpointAnnotation.addMember("name", "$S", methodName);
    }

    factoryBuilder.addAnnotation(endpointAnnotation.build());

    factoryBuilder.addJavadoc(
        "Factory method to create a class wrapping a new $L operation.\n", op.getName());

    // we're going to build the body as add arguments
    CodeBlock.Builder body = CodeBlock.builder();

    // Javadoc @param descriptions, keyed by parameter name, in declaration order.
    Map<String, CodeBlock> paramTags = new LinkedHashMap<>();

    factoryBuilder.addParameter(ParameterSpec.builder(Names.Scope, "scope").build());
    paramTags.put("scope", CodeBlock.of("current scope"));

    // Start from the class type params; generics found in parameter types are added below.
    Set<TypeVariableName> typeVars = new LinkedHashSet<>(typeParams);

    body.addStatement(
        "$T opBuilder = scope.opBuilder($L, $S)", Names.OperationBuilder, OP_NAME_FIELD, className);

    // Names of ConcreteFunction parameters, used by the state-selector dispatch below.
    List<String> functionArgs = new ArrayList<>();
    List<String> iterableFunctionArgs = new ArrayList<>();

    // add the inputs as parameters, and add them to the op builder
    for (ArgDef input : op.getInputArgList()) {
      ApiDef.Arg argDef = argApis.get(input);
      ResolvedType type = resolver.typeOf(input);
      String name = getJavaName(input);

      if (type.javaType.equals(Names.ConcreteFunction)) {
        if (type.iterable) {
          iterableFunctionArgs.add(name);
        } else {
          functionArgs.add(name);
        }
      }

      ParameterSpec.Builder param = ParameterSpec.builder(type.iterableIfIterable().javaType, name);
      String description =
          argDef.getDescription().isEmpty()
              ? String.format("The %s value", name)
              : argDef.getDescription();
      paramTags.put(name, CodeBlock.of("$L", parseDocumentation(description)));
      factoryBuilder.addParameter(param.build());

      typeVars.addAll(type.findGenerics());

      if (type.iterable) {
        body.addStatement("opBuilder.addInputList($T.asOutputs($L))", Names.Operands, name);
      } else {
        body.addStatement("opBuilder.addInput($L.asOutput())", name);
      }
    }

    // add the required attribute params, and build the default type maps for use in the secondary
    // factory
    Map<AttrDef, TypeName> defaultTypes = new HashMap<>();
    Map<String, TypeName> defaultTypeVars = new HashMap<>();
    for (AttrDef attr : requiredAttributes) {
      // Attributes the resolver ties to an input are not exposed as separate parameters.
      if (resolver.partOfInput(attr.getName())) {
        continue;
      }

      ResolvedType type = resolver.typeOf(attr);
      ApiDef.Attr apiAttr = attrApis.get(attr);
      String javaName = getJavaName(attr);

      if (type.javaType.equals(Names.ConcreteFunction)) {
        if (type.iterable) {
          iterableFunctionArgs.add(javaName);
        } else {
          functionArgs.add(javaName);
        }
      }

      // Named paramBuilder (not builder) so it does not shadow the class-level builder field.
      ParameterSpec.Builder paramBuilder =
          ParameterSpec.builder(type.classIfGeneric().listIfIterable().javaType, javaName);

      String description =
          apiAttr.getDescription().isEmpty()
              ? String.format("The value of the %s attribute", javaName)
              : apiAttr.getDescription();
      paramTags.put(javaName, CodeBlock.of("$L", parseDocumentation(description)));

      typeVars.addAll(type.findGenerics());

      factoryBuilder.addParameter(paramBuilder.build());

      // we only add defaults for type variable arguments
      if (attr.hasDefaultValue() && type.shouldWrapInClass()) {
        TypeName defaultType = TypeResolver.forDataType(attr.getDefaultValue().getType());
        // TypeName is a value type: compare with equals(), not reference identity, so an equal
        // but distinct TType instance is still excluded.
        if (!(defaultType instanceof WildcardTypeName) && !defaultType.equals(Names.TType)) {
          defaultTypes.put(attr, defaultType);
          // NOTE(review): assumes shouldWrapInClass() implies javaType is a TypeVariableName.
          defaultTypeVars.put(((TypeVariableName) type.javaType).name, defaultType);
        }
      }

      writeSetAttr(body, attr, type, false);
    }

    // TODO optional function attrs (there currently aren't any)

    // add optional attributes
    if (optionsClass != null || (isStateSubclass && statefulPair.hasOptionalAttrs())) {

      // State subclasses share the Options class declared on their selector class.
      ClassName optionsClassName;
      if (isStateSubclass) {
        optionsClassName = ClassName.get(fullPackage, statefulPair.selectorClassName, "Options");
      } else {
        optionsClassName = ClassName.get(fullPackage, className, "Options");
      }

      factoryBuilder.addParameter(
          ParameterSpec.builder(ArrayTypeName.of(optionsClassName), "options").build());
      paramTags.put("options", CodeBlock.of("$L", "carries optional attribute values"));
      factoryBuilder.varargs();

      body.beginControlFlow("if (options != null)");

      body.beginControlFlow("for ($T opts : options)", optionsClassName);
      for (AttrDef attr : optionalAttributes) {
        String name = getJavaName(attr);
        body.beginControlFlow("if (opts.$L != null)", name);

        writeSetAttr(body, attr, null, true);

        body.endControlFlow();
      }
      body.endControlFlow();

      body.endControlFlow();
    }

    body.addStatement(
        "return new $L(opBuilder.build())", typeParams.isEmpty() ? className : (className + "<>"));

    if (isStateSelector) {
      // A selector does not build the op itself. It discards the body above and delegates to the
      // stateful or stateless variant, based on the function arguments collected earlier.
      body.clear();

      body.addStatement("boolean isStateful = false");
      functionArgs.forEach(
          arg -> {
            body.beginControlFlow("if ($L.isStateful())", arg)
                .addStatement("isStateful = true")
                .endControlFlow();
          });
      iterableFunctionArgs.forEach(
          arg -> {
            body.beginControlFlow("if ($L.stream().anyMatch(x -> x.isStateful()))", arg)
                .addStatement("isStateful = true")
                .endControlFlow();
          });

      StringJoiner argList = new StringJoiner(", ");
      factoryBuilder.parameters.forEach(x -> argList.add(x.name));

      body.beginControlFlow("if (isStateful)")
          .addStatement(
              "return $T.create($L)",
              ClassName.get(fullPackage, statefulPair.statefulClassName),
              argList.toString())
          .nextControlFlow("else")
          .addStatement(
              "return $T.create($L)",
              ClassName.get(fullPackage, statefulPair.statelessClassName),
              argList.toString())
          .endControlFlow();
    }
    factoryBuilder.addCode(body.build());

    paramTags.forEach(
        (param, doc) -> {
          String description = doc.toString();
          if (description.isEmpty() || description.equals("\n")) {
            factoryBuilder.addJavadoc("\n@param $L the $L property", param, param);
          } else {
            factoryBuilder.addJavadoc("\n@param $L $L", param, doc);
          }
        });
    for (TypeVariableName typeVar : typeVars) {
      // Pass the names as $L arguments instead of concatenating them into the format string, so a
      // '$' in a name cannot be misread as a JavaPoet placeholder.
      factoryBuilder.addJavadoc(
          "\n@param <$L> data type for {@code $L} output and operands",
          typeVar.name,
          op.getName());
    }

    factoryBuilder.addTypeVariables(typeVars);
    factoryBuilder.addJavadoc("\n@return a new instance of $L\n", className);
    MethodSpec method = factoryBuilder.build();
    builder.addMethod(method);

    if (!defaultTypes.isEmpty()) {
      buildSecondaryFactory(defaultTypes, defaultTypeVars, method, paramTags);
    }
  }