in cp_examples/sip_finetune/train_sip.py [0:0]
def build_args(arg_defaults=None):
    """Build the argument namespace for SIP fine-tuning.

    Seeds all RNGs via ``pl.seed_everything(1234)`` and registers the
    script's own flags plus those contributed by ``pl.Trainer``,
    ``XrayDataModule`` and ``SipModule``. It then parses the command line
    and fills in missing values:

    * ``default_root_dir`` defaults to the current working directory.
    * ``dataset_dir``, when not given, is read from the ``paths`` section of
      ``../../configs/data.yaml``, resolved against the current working
      directory.
    * ``val_pathology_list`` is chosen according to ``dataset_name``.

    A ``checkpoints`` directory under ``default_root_dir`` is created if it
    does not exist. If it already exists and no ``resume_from_checkpoint``
    was given, the most recently modified ``*.ckpt`` file in it is used.
    Finally a ``ModelCheckpoint`` callback writing to that directory is
    appended to ``callbacks``.

    Args:
        arg_defaults: Optional mapping of parser defaults that override the
            built-in defaults below. Explicit command-line flags still take
            precedence over both. The mapping (and any ``callbacks`` list it
            contains) is not mutated.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        ValueError: If ``dataset_name`` is not one of ``"nih"``, ``"mimic"``,
            ``"chexpert"`` or ``"mimic-chexpert"``.
    """
    pl.seed_everything(1234)
    data_config = Path.cwd() / "../../configs/data.yaml"

    # Built-in defaults; caller-supplied ``arg_defaults`` win over these.
    defaults = {
        "accelerator": "ddp",
        "batch_size": 32,
        "max_epochs": 5,
        "gpus": 1,
        "num_workers": 10,
        "callbacks": [],
    }
    if arg_defaults is not None:
        defaults.update(arg_defaults)

    # ------------
    # args
    # ------------
    parser = ArgumentParser()
    parser.add_argument("--im_size", default=224, type=int)
    parser.add_argument("--uncertain_label", default=np.nan, type=float)
    parser.add_argument("--nan_label", default=np.nan, type=float)
    parser = pl.Trainer.add_argparse_args(parser)
    parser = XrayDataModule.add_model_specific_args(parser)
    parser = SipModule.add_model_specific_args(parser)
    # set_defaults overrides the defaults registered by add_argument above.
    parser.set_defaults(**defaults)
    args = parser.parse_args()

    if args.default_root_dir is None:
        args.default_root_dir = Path.cwd()
    if args.pretrained_file is None:
        warn("Pretrained file not specified, training from scratch.")
    else:
        logging.info(f"Loading pretrained file from {args.pretrained_file}")

    if args.dataset_dir is None:
        with open(data_config, "r") as f:
            paths = yaml.load(f, Loader=yaml.SafeLoader)["paths"]
        # A single if/elif chain, so that any recognized name (including
        # "nih") never falls through to the raising ``else``.
        if args.dataset_name == "nih":
            args.dataset_dir = paths["nih"]
        elif args.dataset_name == "mimic":
            args.dataset_dir = paths["mimic"]
        elif args.dataset_name == "chexpert":
            args.dataset_dir = paths["chexpert"]
        elif args.dataset_name == "mimic-chexpert":
            args.dataset_dir = [paths["chexpert"], paths["mimic"]]
        else:
            raise ValueError("Unrecognized path config.")

    # NIH names the effusion label "Effusion"; the CheXpert/MIMIC label
    # sets call it "Pleural Effusion".
    if args.dataset_name in ("chexpert", "mimic", "mimic-chexpert"):
        args.val_pathology_list = [
            "Atelectasis",
            "Cardiomegaly",
            "Consolidation",
            "Edema",
            "Pleural Effusion",
        ]
    elif args.dataset_name == "nih":
        args.val_pathology_list = [
            "Atelectasis",
            "Cardiomegaly",
            "Consolidation",
            "Edema",
            "Effusion",
        ]
    else:
        raise ValueError("Unrecognized dataset.")

    # ------------
    # checkpoints
    # ------------
    checkpoint_dir = Path(args.default_root_dir) / "checkpoints"
    if not checkpoint_dir.exists():
        checkpoint_dir.mkdir(parents=True)
    elif args.resume_from_checkpoint is None:
        # Auto-resume from the newest checkpoint by modification time.
        ckpt_list = sorted(checkpoint_dir.glob("*.ckpt"), key=os.path.getmtime)
        if ckpt_list:
            args.resume_from_checkpoint = str(ckpt_list[-1])
    # Build a fresh list instead of appending in place, so a callbacks list
    # supplied via ``arg_defaults`` is not modified by this call.
    args.callbacks = [
        *(args.callbacks or []),
        pl.callbacks.ModelCheckpoint(dirpath=checkpoint_dir, verbose=True),
    ]
    return args