in cp_examples/moco_pretrain/train_moco.py [0:0]
def cli_main(args):
    """Build the X-ray data module and the MoCo model from ``args``, then train.

    Args:
        args: Parsed command-line namespace. Attributes read directly here:
            ``im_size``, ``dataset_name``, ``dataset_dir``, ``batch_size``,
            ``num_workers``, ``arch``, ``feature_dim``, ``queue_size``,
            ``use_mlp``, ``learning_rate``, ``momentum``, ``weight_decay``
            and ``max_epochs``. The whole namespace is also handed to
            ``pl.Trainer.from_argparse_args``, so it should carry any
            Trainer flags as well.

    Returns:
        None. Training happens as a side effect of ``fit``.
    """
    # ---- data -----------------------------------------------------------
    # A single random augmentation list is reused for train, val and test.
    # NOTE(review): val/test being augmented too is presumably intentional
    # (the model needs two views per sample outside training as well) --
    # confirm before changing.
    augmentations = [
        transforms.RandomResizedCrop(args.im_size, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        # Ordered after ToTensor, so these presumably act on tensors
        # (assumption -- check their definitions).
        RandomGaussianBlur(),
        AddGaussianNoise(snr_range=(4, 8)),
        HistogramNormalize(),
        TensorToRGB(),
    ]
    # Each split gets its own Compose built from the same transform list.
    split_transforms = {
        key: Compose(augmentations)
        for key in ("train_transform", "val_transform", "test_transform")
    }
    data_module = XrayDataModule(
        dataset_name=args.dataset_name,
        dataset_dir=args.dataset_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        # Presumably makes the data module emit two augmented views per
        # sample (query/key pair) -- verify in XrayDataModule.
        use_two_images=True,
        **split_transforms,
    )

    # ---- model ----------------------------------------------------------
    # These hyperparameters have the same names on ``args`` and on the
    # MoCoModule constructor, so they are forwarded by name.
    model_kwargs = {
        name: getattr(args, name)
        for name in (
            "arch",
            "feature_dim",
            "queue_size",
            "use_mlp",
            "learning_rate",
            "momentum",
            "weight_decay",
        )
    }
    # ``epochs`` is fed from ``args.max_epochs``, presumably so the model's
    # epoch count agrees with the Trainer's.
    model = MoCoModule(**model_kwargs, epochs=args.max_epochs)

    # ---- training -------------------------------------------------------
    pl.Trainer.from_argparse_args(args).fit(model, datamodule=data_module)