static void createHyperparameterTuningJobPythonPackageSample()

in aiplatform/src/main/java/aiplatform/CreateHyperparameterTuningJobPythonPackageSample.java [54:171]


  static void createHyperparameterTuningJobPythonPackageSample(
      String project,
      String displayName,
      String executorImageUri,
      String packageUri,
      String pythonModule)
      throws IOException {
    JobServiceSettings settings =
        JobServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();
    String location = "us-central1";

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (JobServiceClient client = JobServiceClient.create(settings)) {
      // study spec
      MetricSpec metric =
          MetricSpec.newBuilder().setMetricId("val_rmse").setGoal(GoalType.MINIMIZE).build();

      // decay
      DoubleValueSpec doubleValueSpec =
          DoubleValueSpec.newBuilder().setMinValue(1e-07).setMaxValue(1).build();
      ParameterSpec parameterDecaySpec =
          ParameterSpec.newBuilder()
              .setParameterId("decay")
              .setDoubleValueSpec(doubleValueSpec)
              .setScaleType(ScaleType.UNIT_LINEAR_SCALE)
              .build();
      Double[] decayValues = {32.0, 64.0};
      DiscreteValueCondition discreteValueDecay =
          DiscreteValueCondition.newBuilder().addAllValues(Arrays.asList(decayValues)).build();
      ConditionalParameterSpec conditionalParameterDecay =
          ConditionalParameterSpec.newBuilder()
              .setParameterSpec(parameterDecaySpec)
              .setParentDiscreteValues(discreteValueDecay)
              .build();

      // learning rate
      ParameterSpec parameterLearningSpec =
          ParameterSpec.newBuilder()
              .setParameterId("learning_rate")
              .setDoubleValueSpec(doubleValueSpec) // Use the same min/max as for decay
              .setScaleType(ScaleType.UNIT_LINEAR_SCALE)
              .build();

      Double[] learningRateValues = {4.0, 8.0, 16.0};
      DiscreteValueCondition discreteValueLearning =
          DiscreteValueCondition.newBuilder()
              .addAllValues(Arrays.asList(learningRateValues))
              .build();
      ConditionalParameterSpec conditionalParameterLearning =
          ConditionalParameterSpec.newBuilder()
              .setParameterSpec(parameterLearningSpec)
              .setParentDiscreteValues(discreteValueLearning)
              .build();

      // batch size
      Double[] batchSizeValues = {4.0, 8.0, 16.0, 32.0, 64.0, 128.0};

      DiscreteValueSpec discreteValueSpec =
          DiscreteValueSpec.newBuilder().addAllValues(Arrays.asList(batchSizeValues)).build();
      ParameterSpec parameter =
          ParameterSpec.newBuilder()
              .setParameterId("batch_size")
              .setDiscreteValueSpec(discreteValueSpec)
              .setScaleType(ScaleType.UNIT_LINEAR_SCALE)
              .addConditionalParameterSpecs(conditionalParameterDecay)
              .addConditionalParameterSpecs(conditionalParameterLearning)
              .build();

      // trial_job_spec
      MachineSpec machineSpec =
          MachineSpec.newBuilder()
              .setMachineType("n1-standard-4")
              .setAcceleratorType(AcceleratorType.NVIDIA_TESLA_T4)
              .setAcceleratorCount(1)
              .build();

      PythonPackageSpec pythonPackageSpec =
          PythonPackageSpec.newBuilder()
              .setExecutorImageUri(executorImageUri)
              .addPackageUris(packageUri)
              .setPythonModule(pythonModule)
              .build();

      WorkerPoolSpec workerPoolSpec =
          WorkerPoolSpec.newBuilder()
              .setMachineSpec(machineSpec)
              .setReplicaCount(1)
              .setPythonPackageSpec(pythonPackageSpec)
              .build();

      StudySpec studySpec =
          StudySpec.newBuilder()
              .addMetrics(metric)
              .addParameters(parameter)
              .setAlgorithm(StudySpec.Algorithm.RANDOM_SEARCH)
              .build();
      CustomJobSpec trialJobSpec =
          CustomJobSpec.newBuilder().addWorkerPoolSpecs(workerPoolSpec).build();
      // hyperparameter_tuning_job
      HyperparameterTuningJob hyperparameterTuningJob =
          HyperparameterTuningJob.newBuilder()
              .setDisplayName(displayName)
              .setMaxTrialCount(4)
              .setParallelTrialCount(2)
              .setStudySpec(studySpec)
              .setTrialJobSpec(trialJobSpec)
              .build();
      LocationName parent = LocationName.of(project, location);
      HyperparameterTuningJob response =
          client.createHyperparameterTuningJob(parent, hyperparameterTuningJob);
      System.out.format("response: %s\n", response);
      System.out.format("Name: %s\n", response.getName());
    }
  }