private Set buildInputsClass()

in tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java [942:1055]


  /**
   * Generates the nested {@code public static} inputs class (named by {@code INPUTS_CLASS_NAME})
   * and adds it to this generator's class builder.
   *
   * <p>The generated class extends {@code RawOpInputs} parameterized with the op class, declares
   * one {@code public final} field per op input and per attribute that has a resolved attribute
   * type, and has a single public constructor taking a {@code GraphOperation} named {@code op}
   * that fills those fields.
   *
   * @return the generic type variables found on the op's inputs, in the order they were first
   *     encountered; these are also the type variables declared on the generated class
   */
  private Set<TypeVariableName> buildInputsClass() {
    TypeSpec.Builder inputsClass =
        TypeSpec.classBuilder(INPUTS_CLASS_NAME).addModifiers(Modifier.PUBLIC, Modifier.STATIC);

    // Type variables used by the inputs; insertion-ordered so the generated signature is stable.
    Set<TypeVariableName> inputTypeVars = new LinkedHashSet<>();
    // Quoted attribute names, later spliced into the generated super(...) call.
    StringJoiner quotedAttrNames = new StringJoiner(", ");
    // Statements of the generated constructor that assign the fields.
    CodeBlock.Builder assignments = CodeBlock.builder();
    assignments.addStatement("int inputIndex = 0");

    // Declare a field for every input and emit the code that reads it from the graph operation.
    for (ArgDef inputArg : op.getInputArgList()) {
      ResolvedType resolved = resolver.typeOf(inputArg);
      String fieldName = getJavaName(inputArg);
      ApiDef.Arg argApi = argApis.get(inputArg);

      inputTypeVars.addAll(resolved.findGenerics());

      TypeName fieldType = resolved.iterableIfIterable().javaType;

      String doc = argApi.getDescription();
      if (doc.isEmpty()) {
        doc = String.format("The %s input", fieldName);
      }

      FieldSpec inputField =
          FieldSpec.builder(fieldType, fieldName, Modifier.PUBLIC, Modifier.FINAL)
              .addJavadoc("$L", parseDocumentation(doc))
              .build();
      inputsClass.addField(inputField);

      if (!resolved.iterable) {
        // Single input: occupies exactly one slot.
        assignments.addStatement("$L = ($T) op.input(inputIndex++)", fieldName, fieldType);
        continue;
      }

      // List input: look up its length, read that many slots, then advance the running index.
      String lengthVar = fieldName + "Length";
      assignments.addStatement(
          "int $L = op.inputListLength($S)", lengthVar, inputArg.getName());
      assignments.addStatement(
          "$L = $T.asList(($T) op.inputList(inputIndex, $L))",
          fieldName,
          Names.Arrays,
          ArrayTypeName.of(resolved.javaType),
          lengthVar);
      assignments.addStatement("inputIndex += $L", lengthVar);
    }

    // Declare a field for every attribute that resolves to a known attribute type.
    for (AttrDef attrDef : op.getAttrList()) {
      ResolvedType resolved = resolver.typeOf(attrDef);
      String fieldName = getJavaName(attrDef);

      if (resolved.attributeType == null) {
        continue;
      }

      ApiDef.Attr attrApi = attrApis.get(attrDef);

      String doc = attrApi.getDescription();
      if (doc.isEmpty()) {
        doc = String.format("The %s attribute", fieldName);
      }

      TypeName fieldType =
          resolved.iterable ? ArrayTypeName.of(resolved.jniType) : resolved.jniType;

      quotedAttrNames.add(CodeBlock.of("$S", attrDef.getName()).toString());

      // NOTE(review): input descriptions above go through parseDocumentation, attribute
      // descriptions do not -- confirm whether that difference is intended.
      FieldSpec attrField =
          FieldSpec.builder(fieldType, fieldName, Modifier.PUBLIC, Modifier.FINAL)
              .addJavadoc("$L", doc)
              .build();
      inputsClass.addField(attrField);

      assignments.addStatement(
          "$L = op.attributes().getAttr$L($S)",
          fieldName,
          resolved.attributeType.getterName(resolved.iterable),
          attrDef.getName());
    }

    // Type arguments for the op class as seen from the inputs class: keep the variables the
    // inputs use, and fall back to an unbounded wildcard for the ones they do not.
    List<TypeName> opTypeArgs = new ArrayList<>();
    for (TypeVariableName opTypeVar : this.builder.typeVariables) {
      opTypeArgs.add(
          inputTypeVars.contains(opTypeVar)
              ? opTypeVar
              : WildcardTypeName.subtypeOf(TypeName.OBJECT));
    }

    boolean opIsGeneric = !this.builder.typeVariables.isEmpty();

    TypeName opType = className();
    if (opIsGeneric) {
      opType =
          ParameterizedTypeName.get((ClassName) opType, opTypeArgs.toArray(new TypeName[0]));
    }

    inputsClass.superclass(ParameterizedTypeName.get(Names.RawOpInputs, opType));

    // Generated constructor: the super(...) call first, then the field assignments.
    CodeBlock.Builder ctorBody = CodeBlock.builder();
    ctorBody.addStatement(
        "super(new $L(op), op, $T.asList($L))",
        opIsGeneric ? className + "<>" : className,
        Names.Arrays,
        quotedAttrNames.toString());
    ctorBody.add(assignments.build());

    MethodSpec.Builder ctor = MethodSpec.constructorBuilder().addModifiers(Modifier.PUBLIC);
    ctor.addParameter(Names.GraphOperation, "op");
    ctor.addCode(ctorBody.build());

    inputsClass.addMethod(ctor.build());
    inputsClass.addTypeVariables(inputTypeVars);
    addInputsMetadataAnnotation(inputsClass);
    this.builder.addType(inputsClass.build());
    return inputTypeVars;
  }