in cp_examples/mip_finetune/train_mip.py [0:0]
def cli_main(args):
    """Build the data pipeline and MIP model from ``args`` and run training.

    Steps, in order:

    1. Construct the training and validation transform lists and pass them to
       ``create_data_module``.
    2. Compute positive-class weights from the training dataset's CSV via
       ``fetch_pos_weights`` and build a ``MIPModule``.
    3. Create a ``pl.Trainer`` from ``args`` and fit the model on the data
       module.

    Args:
        args: Parsed command-line namespace. This function reads ``im_size``,
            ``uncertain_label`` and ``nan_label`` directly; the whole
            namespace is also forwarded to ``MIPModule`` and to
            ``pl.Trainer.from_argparse_args``.

    Returns:
        None. Training runs as a side effect of ``Trainer.fit``.
    """

    def spatial_steps():
        # Fresh Resize + CenterCrop instances, both sized by args.im_size.
        return [
            transforms.Resize(args.im_size),
            transforms.CenterCrop(args.im_size),
        ]

    def tensor_steps():
        # Fresh instances of the steps shared by both pipelines after the
        # spatial ones. RemapLabel(-1, args.uncertain_label) presumably maps
        # labels equal to -1 onto args.uncertain_label -- confirm in RemapLabel.
        return [
            transforms.ToTensor(),
            HistogramNormalize(),
            TensorToRGB(),
            RemapLabel(-1, args.uncertain_label),
        ]

    # --- data ---
    # Training inserts random horizontal/vertical flips between the spatial
    # and tensor steps, and appends NanToInt(args.nan_label) at the end.
    train_transforms = [
        *spatial_steps(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        *tensor_steps(),
        NanToInt(args.nan_label),
    ]
    # Validation uses the same steps minus the random flips.
    # NOTE(review): NanToInt is not applied here, unlike in training --
    # confirm the validation step is meant to see the raw labels.
    val_transforms = [*spatial_steps(), *tensor_steps()]
    data_module = create_data_module(train_transforms, val_transforms)

    # --- model ---
    train_csv = data_module.train_dataset.csv
    pos_weights = fetch_pos_weights(
        csv=train_csv,
        label_list=data_module.label_list,
        uncertain_label=args.uncertain_label,
        nan_label=args.nan_label,
    )
    model = MIPModule(args, data_module.label_list, pos_weights)

    # --- training ---
    pl.Trainer.from_argparse_args(args).fit(model, datamodule=data_module)