in cp_examples/sip_finetune/train_sip.py [0:0]
def cli_main(args):
    """Assemble the X-ray data pipeline and SIP model from CLI args, then train.

    Args:
        args: Parsed argument namespace. Fields read directly here are
            ``im_size``, ``uncertain_label``, ``nan_label``, ``dataset_name``,
            ``dataset_dir``, ``batch_size``, ``num_workers``, ``arch``,
            ``pretrained_file``, ``val_pathology_list``, ``learning_rate`` and
            ``max_epochs``. The whole namespace is also forwarded to
            ``pl.Trainer.from_argparse_args`` to configure the trainer.

    Returns:
        None. Training is performed as a side effect of ``Trainer.fit``.
    """
    # ------------
    # data
    # ------------
    def spatial_steps():
        # Resize followed by a center crop, both driven by args.im_size.
        return [
            transforms.Resize(args.im_size),
            transforms.CenterCrop(args.im_size),
        ]

    def tensor_steps():
        # Conversion to tensor, normalization, RGB expansion and label remap.
        # NOTE(review): RemapLabel(-1, ...) presumably rewrites the label
        # value -1 to args.uncertain_label — confirm against its definition.
        return [
            transforms.ToTensor(),
            HistogramNormalize(),
            TensorToRGB(),
            RemapLabel(-1, args.uncertain_label),
        ]

    # Training adds random flips (augmentation) and a NaN-label conversion.
    train_steps = (
        spatial_steps()
        + [transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip()]
        + tensor_steps()
        + [NanToInt(args.nan_label)]
    )
    # Evaluation pipeline: deterministic, no flips.
    # NOTE(review): NanToInt is intentionally(?) absent here, so missing
    # labels presumably stay NaN for val/test — confirm metrics expect this.
    eval_steps = spatial_steps() + tensor_steps()

    data_module = XrayDataModule(
        dataset_name=args.dataset_name,
        dataset_dir=args.dataset_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        train_transform=Compose(train_steps),
        val_transform=Compose(eval_steps),
        test_transform=Compose(eval_steps),
    )

    # ------------
    # model
    # ------------
    labels = data_module.label_list
    # Positive-class weights are derived from the training split's CSV.
    pos_weights = fetch_pos_weights(
        dataset_name=args.dataset_name,
        csv=data_module.train_dataset.csv,
        label_list=labels,
        uncertain_label=args.uncertain_label,
        nan_label=args.nan_label,
    )
    model = SipModule(
        arch=args.arch,
        num_classes=len(labels),
        pretrained_file=args.pretrained_file,
        label_list=labels,
        val_pathology_list=args.val_pathology_list,
        learning_rate=args.learning_rate,
        pos_weights=pos_weights,
        epochs=args.max_epochs,
    )

    # ------------
    # training
    # ------------
    pl.Trainer.from_argparse_args(args).fit(model, datamodule=data_module)