def cli_main()

in cp_examples/sip_finetune/train_sip.py [0:0]


def _build_transform_list(args, augment):
    """Assemble the ordered transform list for one data split.

    All splits share the same resize/crop, tensor conversion, normalization
    and label-remapping steps; the training split additionally gets random
    horizontal/vertical flips. Keeping the label steps in one place ensures
    train, val and test labels are encoded the same way.

    Args:
        args: Parsed options; reads ``im_size``, ``uncertain_label`` and
            ``nan_label``.
        augment: When true, insert the random flip augmentations after the
            center crop.

    Returns:
        A new list of freshly constructed transform objects.
    """
    steps = [
        transforms.Resize(args.im_size),
        transforms.CenterCrop(args.im_size),
    ]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
        steps.append(transforms.RandomVerticalFlip())
    steps.extend(
        [
            transforms.ToTensor(),
            HistogramNormalize(),
            TensorToRGB(),
            # Presumably maps the "-1" (uncertain) label value to
            # args.uncertain_label -- confirm against RemapLabel.
            RemapLabel(-1, args.uncertain_label),
            # Applied to every split (previously train only) so that
            # evaluation labels match the encoding used for training and
            # for the positive-weight computation.
            NanToInt(args.nan_label),
        ]
    )
    return steps


def cli_main(args):
    """Build the data module and model from ``args`` and run training.

    Args:
        args: Parsed command-line options. Besides the fields consumed
            directly here (``im_size``, ``uncertain_label``, ``nan_label``,
            ``dataset_name``, ``dataset_dir``, ``batch_size``,
            ``num_workers``, ``arch``, ``pretrained_file``,
            ``val_pathology_list``, ``learning_rate``, ``max_epochs``), the
            object is handed to ``pl.Trainer.from_argparse_args`` to
            configure the trainer.
    """
    # ------------
    # data
    # ------------
    train_transform_list = _build_transform_list(args, augment=True)
    # Validation and test share one deterministic (un-augmented) pipeline.
    val_transform_list = _build_transform_list(args, augment=False)
    data_module = XrayDataModule(
        dataset_name=args.dataset_name,
        dataset_dir=args.dataset_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        train_transform=Compose(train_transform_list),
        val_transform=Compose(val_transform_list),
        test_transform=Compose(val_transform_list),
    )

    # ------------
    # model
    # ------------
    # Positive-class weights are derived from the training CSV using the same
    # uncertain/NaN label encoding that the transforms apply.
    pos_weights = fetch_pos_weights(
        dataset_name=args.dataset_name,
        csv=data_module.train_dataset.csv,
        label_list=data_module.label_list,
        uncertain_label=args.uncertain_label,
        nan_label=args.nan_label,
    )
    model = SipModule(
        arch=args.arch,
        num_classes=len(data_module.label_list),
        pretrained_file=args.pretrained_file,
        label_list=data_module.label_list,
        val_pathology_list=args.val_pathology_list,
        learning_rate=args.learning_rate,
        pos_weights=pos_weights,
        epochs=args.max_epochs,
    )

    # ------------
    # training
    # ------------
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model, datamodule=data_module)