def train_custom_model()

in 10_mlops/train_on_vertexai.py [0:0]


def train_custom_model(data_set, timestamp, develop_mode, cpu_only_mode, tf_version, extra_args=None,
                       *, machine_type='n1-standard-4'):
    """Train ``model.py`` as a Vertex AI custom training job and return the model.

    Args:
        data_set: Dataset handed to ``CustomTrainingJob.run``; it must contain a
            ``data_split`` column, which is used as the predefined split column.
        timestamp: Suffix appended to ENDPOINT_NAME to form the model display
            name (the job display name is ``train-<model display name>``).
        develop_mode: If truthy, ``--develop`` is passed to the training script
            and the job is run synchronously (``sync=develop_mode``).
        cpu_only_mode: If truthy, train with the CPU training image and no
            accelerator; otherwise train with the GPU training image on one
            NVIDIA T4.
        tf_version: Version tag substituted into the prebuilt container URIs
            (presumably of the form ``2-6`` -- confirm against the caller).
        extra_args: Optional iterable of extra command-line arguments appended
            after the ``--bucket``/``--develop`` arguments.
        machine_type: Machine type of the single training replica.

    Returns:
        Whatever ``job.run`` returns (the trained Vertex AI model). When
        ``develop_mode`` is falsy the job is launched with ``sync=False``, so
        the returned model may still be pending (see the aiplatform docs).
    """
    # Set up training and deployment infra.
    if cpu_only_mode:
        train_image = 'us-docker.pkg.dev/vertex-ai/training/tf-cpu.{}:latest'.format(tf_version)
    else:
        train_image = 'us-docker.pkg.dev/vertex-ai/training/tf-gpu.{}:latest'.format(tf_version)
    # Serving always uses the CPU prediction image, even when training on a GPU.
    deploy_image = 'us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.{}:latest'.format(tf_version)

    # Train.
    model_display_name = '{}-{}'.format(ENDPOINT_NAME, timestamp)
    job = aiplatform.CustomTrainingJob(
        display_name='train-{}'.format(model_display_name),
        script_path="model.py",
        container_uri=train_image,
        requirements=['cloudml-hypertune'],  # any extra Python packages
        model_serving_container_image_uri=deploy_image,
    )

    # Arguments forwarded to model.py; a new list, so extra_args is never mutated.
    model_args = ['--bucket', BUCKET]
    if develop_mode:
        model_args.append('--develop')
    if extra_args:
        model_args.extend(extra_args)

    # Single source of truth for the run configuration; only the accelerator
    # settings differ between CPU and GPU training.
    # See https://googleapis.dev/python/aiplatform/latest/aiplatform.html
    run_kwargs = dict(
        dataset=data_set,
        predefined_split_column_name='data_split',
        model_display_name=model_display_name,
        args=model_args,
        replica_count=1,
        machine_type=machine_type,
        sync=develop_mode,
    )
    if not cpu_only_mode:
        # See https://cloud.google.com/vertex-ai/docs/general/locations#accelerators
        run_kwargs.update(
            accelerator_type=aip.AcceleratorType.NVIDIA_TESLA_T4.name,
            accelerator_count=1,
        )
    return job.run(**run_kwargs)