def run()

in cli/jobs/pipelines/tensorflow-image-segmentation/src/run.py
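
Main entry point of the training script: builds the train/validation image+mask datasets, compiles the segmentation model inside the distribution strategy scope, runs the distributed training sequence, saves the final model, and logs timings through MLflow.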


import logging
import time

import mlflow
from tensorflow import keras

# NOTE: repo-local helpers used below (TensorflowDistributedModelTrainingSequence,
# ImageAndMaskSequenceDataset, load_model, LogTimeBlock, LogDiskIOBlock) and the
# SCRIPT_START_TIME constant are imported/defined elsewhere in run.py.


def run(args):
    """Run the script using CLI arguments"""
    logger = logging.getLogger(__name__)
    logger.info(f"Running with arguments: {args}")

    # MLFLOW: initialize mlflow (called once for the entire script)
    mlflow.start_run()

    # use a handler for the training sequence
    training_handler = TensorflowDistributedModelTrainingSequence()

    # set up CUDA devices and distributed training config
    training_handler.setup_config(args)

    # DATA
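    # LogTimeBlock / LogDiskIOBlock are repo-local context managers that,
    # as their names suggest, log elapsed time and disk I/O for the enclosed block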
    with LogTimeBlock("build_image_datasets", enabled=True), LogDiskIOBlock(
        "build_image_datasets", enabled=True
    ):
        train_dataset_helper = ImageAndMaskSequenceDataset(
            images_dir=args.train_images,
            masks_dir=args.train_masks,
            images_filename_pattern=args.images_filename_pattern,
            masks_filename_pattern=args.masks_filename_pattern,
            images_type=args.images_type,
        )

        # the helper returns a dataset of image paths, plus a loading
        # function to use with map() for loading the actual data
        train_dataset, train_loading_function = train_dataset_helper.dataset(
            input_size=args.model_input_size
        )

        test_dataset_helper = ImageAndMaskSequenceDataset(
            images_dir=args.test_images,
            masks_dir=args.test_masks,
            images_filename_pattern=args.images_filename_pattern,
            masks_filename_pattern=args.masks_filename_pattern,
            images_type="png",  # masks need to be in png
        )

        # the helper returns a dataset of image paths, plus a loading
        # function to use with map() for loading the actual data
        val_dataset, val_loading_function = test_dataset_helper.dataset(
            input_size=args.model_input_size
        )

        training_handler.setup_datasets(
            train_dataset,
            train_loading_function,
            val_dataset,
            val_loading_function,
            training_dataset_length=len(
                train_dataset_helper
            ),  # used to shuffle and repeat dataset
        )

    # free up RAM: clear any global Keras state left over from a previous model build
    keras.backend.clear_session()

    # DISTRIBUTED: build model
    with LogTimeBlock("load_model", enabled=True):
        with training_handler.strategy.scope():
            model = load_model(
                model_arch=args.model_arch,
                input_size=args.model_input_size,
                num_classes=args.num_classes,
            )

            # print model summary to stdout
            model.summary()

            # Configure the model for training.
            # The target masks contain integer class ids, so the loss
            # passed through args.loss is typically the "sparse" version
            # of categorical_crossentropy.
            model.compile(
                optimizer=args.optimizer,
                loss=args.loss,
                metrics=["accuracy"],
                # run_eagerly=True  # uncomment to debug the training step eagerly
            )

    # set up the model for distributed training
    training_handler.setup_model(model)

    mlflow.log_metric("start_to_fit_time", time.time() - SCRIPT_START_TIME)

    # run the training sequence
    # NOTE: the number of epochs is provided through args (num_epochs)
    try:
        training_handler.train()  # TODO: checkpoints_dir=args.checkpoints
    except RuntimeError as runtime_exception:
        # on runtime errors (e.g. CUDA out of memory),
        # log a report before re-raising to fail the job
        training_handler.runtime_error_report(runtime_exception)
        raise

    # save (and optionally register) the final model
    if args.model_output:
        training_handler.save(
            args.model_output,
            name=f"epoch-{args.num_epochs}",
            register_as=args.register_model_as,
        )

    # properly tear down distributed resources
    training_handler.close()

    # log the total wall time of the script
    mlflow.log_metric("wall_time", time.time() - SCRIPT_START_TIME)

    # MLFLOW: finalize mlflow (called once for the entire script)
    mlflow.end_run()

    logger.info("run() completed")