def cli_main()

in cp_examples/moco_pretrain/train_moco.py [0:0]


def cli_main(args):
    """Pre-train a MoCo model on X-ray data configured entirely by *args*.

    Builds the image augmentation pipeline and an ``XrayDataModule``,
    constructs a ``MoCoModule`` from the relevant hyperparameters, then
    creates a PyTorch Lightning ``Trainer`` with ``from_argparse_args`` and
    calls ``fit`` on the model and data module.

    Args:
        args: Parsed command-line options (presumably an
            ``argparse.Namespace`` - TODO confirm against the caller). It
            must expose ``im_size``, ``dataset_name``, ``dataset_dir``,
            ``batch_size``, ``num_workers``, ``arch``, ``feature_dim``,
            ``queue_size``, ``use_mlp``, ``learning_rate``, ``momentum``,
            ``weight_decay`` and ``max_epochs``, plus whatever Trainer
            options ``from_argparse_args`` reads.

    Returns:
        None. Training runs as a side effect of ``Trainer.fit``.
    """
    # ---- data ------------------------------------------------------------
    # Applied in list order: the crop and flips run before ToTensor, and the
    # remaining steps run after it (presumably they operate on tensors -
    # confirm against their definitions).
    augmentations = [
        transforms.RandomResizedCrop(args.im_size, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        RandomGaussianBlur(),
        AddGaussianNoise(snr_range=(4, 8)),
        HistogramNormalize(),
        TensorToRGB(),
    ]

    def make_pipeline():
        # Each split gets its own Compose object; all of them wrap the same
        # list of augmentation instances.
        return Compose(augmentations)

    # NOTE(review): validation and test reuse the random training
    # augmentations. With use_two_images=True this presumably produces two
    # augmented views per sample so the contrastive loss can be computed at
    # every stage - confirm this is intended.
    datamodule = XrayDataModule(
        dataset_name=args.dataset_name,
        dataset_dir=args.dataset_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        use_two_images=True,
        train_transform=make_pipeline(),
        val_transform=make_pipeline(),
        test_transform=make_pipeline(),
    )

    # ---- model -----------------------------------------------------------
    # These options are forwarded to MoCoModule under their own names.
    forwarded_options = (
        "arch",
        "feature_dim",
        "queue_size",
        "use_mlp",
        "learning_rate",
        "momentum",
        "weight_decay",
    )
    model_kwargs = {option: getattr(args, option) for option in forwarded_options}
    # The module calls its epoch count `epochs`; the CLI exposes it as
    # `max_epochs`, the same option the Trainer reads.
    model_kwargs["epochs"] = args.max_epochs
    model = MoCoModule(**model_kwargs)

    # ---- training --------------------------------------------------------
    pl.Trainer.from_argparse_args(args).fit(model, datamodule=datamodule)