def cli_main()

in cp_examples/mip_finetune/train_mip.py [0:0]


def cli_main(args):
    """Fine-tune a ``MIPModule`` using settings from parsed CLI arguments.

    Steps, in order:

    1. Build the training and validation transform pipelines.
    2. Create the data module from those pipelines.
    3. Compute per-label positive weights from the training CSV.
    4. Construct the ``MIPModule``.
    5. Build a PyTorch Lightning ``Trainer`` from ``args`` and fit the model.

    Args:
        args: Parsed argument namespace. This function reads ``im_size``,
            ``uncertain_label`` and ``nan_label`` directly. It also passes the
            whole namespace to ``MIPModule`` and to
            ``pl.Trainer.from_argparse_args``, which may read more fields.
    """

    # -- data ---------------------------------------------------------------
    def _geometry_ops(augment):
        # Resize and center-crop to ``im_size``. Training additionally applies
        # random horizontal and vertical flips.
        ops = [
            transforms.Resize(args.im_size),
            transforms.CenterCrop(args.im_size),
        ]
        if augment:
            ops.extend(
                [transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip()]
            )
        return ops

    def _tensor_and_label_ops():
        # Fresh instances on every call, so the train and val pipelines never
        # share transform objects.
        return [
            transforms.ToTensor(),
            HistogramNormalize(),
            TensorToRGB(),
            # Remaps label value -1 to ``uncertain_label``.
            # NOTE(review): presumably -1 marks "uncertain" in the source CSV.
            RemapLabel(-1, args.uncertain_label),
        ]

    # Only the training pipeline maps NaN labels to ``nan_label``.
    # NOTE(review): validation keeps NaNs. This is presumably intentional, so
    # evaluation can skip missing labels; confirm against MIPModule.
    train_ops = _geometry_ops(augment=True) + _tensor_and_label_ops()
    train_ops.append(NanToInt(args.nan_label))
    eval_ops = _geometry_ops(augment=False) + _tensor_and_label_ops()
    data_module = create_data_module(train_ops, eval_ops)

    # -- model --------------------------------------------------------------
    train_csv = data_module.train_dataset.csv
    labels = data_module.label_list
    weights = fetch_pos_weights(
        csv=train_csv,
        label_list=labels,
        uncertain_label=args.uncertain_label,
        nan_label=args.nan_label,
    )
    mip_model = MIPModule(args, labels, weights)

    # -- training -----------------------------------------------------------
    pl.Trainer.from_argparse_args(args).fit(mip_model, datamodule=data_module)