# fastmri_examples/varnet/varnet_reproduce_20201111/varnet_knee_leaderboard.py

import os
import pathlib
from argparse import ArgumentParser

import pytorch_lightning as pl

from fastmri.data.mri_data import fetch_dir
from fastmri.pl_modules import FastMriDataModule, VarNetModule


def build_args():
    parser = ArgumentParser()

    # basic args
    path_config = pathlib.Path("../../fastmri_dirs.yaml")
    backend = "ddp"
    num_gpus = 32
    batch_size = 1

    # set defaults based on optional directory config
    data_path = fetch_dir("knee_path", path_config)
    default_root_dir = (
        fetch_dir("log_path", path_config) / "varnet" / "knee_leaderboard"
    )

    # client arguments
    parser.add_argument(
        "--mode",
        default="train",
        choices=("train", "test"),
        type=str,
        help="Operation mode",
    )

    # data transform params
    parser.add_argument(
        "--mask_type",
        choices=("random", "equispaced"),
        default="equispaced",
        type=str,
        help="Type of k-space mask",
    )
    parser.add_argument(
        "--center_fractions",
        nargs="+",
        default=[0.08, 0.04],
        type=float,
        help="Fraction of center (low-frequency) k-space lines retained in the mask",
    )
    parser.add_argument(
        "--accelerations",
        nargs="+",
        default=[4, 8],
        type=int,
        help="Acceleration rates to use for masks",
    )

    # data config
    parser = FastMriDataModule.add_data_specific_args(parser)
    parser.set_defaults(
        data_path=data_path,  # path to fastMRI data
        mask_type="equispaced",  # VarNet uses equispaced mask
        challenge="multicoil",  # only multicoil implemented for VarNet
        batch_size=batch_size,  # number of samples per batch
        test_path=None,  # path for test split, overwrites data_path
    )

    # module config
    parser = VarNetModule.add_model_specific_args(parser)
    parser.set_defaults(
        num_cascades=12,  # number of unrolled iterations
        pools=4,  # number of pooling layers for U-Net
        chans=18,  # number of top-level channels for U-Net
        sens_pools=4,  # number of pooling layers for sense est. U-Net
        sens_chans=8,  # number of top-level channels for sense est. U-Net
        lr=0.0003,  # Adam learning rate
        lr_step_size=40,  # epoch at which to decrease learning rate
        lr_gamma=0.1,  # extent to which to decrease learning rate
        weight_decay=0.0,  # weight regularization strength
    )

    # trainer config
    parser = pl.Trainer.add_argparse_args(parser)
    parser.set_defaults(
        gpus=num_gpus,  # number of gpus to use
        replace_sampler_ddp=False,  # this is necessary for volume dispatch during val
        accelerator=backend,  # what distributed version to use
        seed=42,  # random seed
        deterministic=True,  # makes things slower, but deterministic
        default_root_dir=default_root_dir,  # directory for logs and checkpoints
        max_epochs=50,  # max number of epochs
    )

    args = parser.parse_args()

    # configure checkpointing in checkpoint_dir
    checkpoint_dir = args.default_root_dir / "checkpoints"
    if not checkpoint_dir.exists():
        checkpoint_dir.mkdir(parents=True)

    args.checkpoint_callback = pl.callbacks.ModelCheckpoint(
        filepath=checkpoint_dir,
        verbose=True,
        prefix="",
    )
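
    # Note: filepath/prefix follow the older pytorch_lightning ModelCheckpoint
    # API. On newer Lightning releases the rough equivalent would be
    # ModelCheckpoint(dirpath=checkpoint_dir, verbose=True), since the
    # filepath and prefix arguments were removed.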

    # set default checkpoint if one exists in our checkpoint directory
    if args.resume_from_checkpoint is None:
        ckpt_list = sorted(checkpoint_dir.glob("*.ckpt"), key=os.path.getmtime)
        if ckpt_list:
            args.resume_from_checkpoint = str(ckpt_list[-1])

    return args
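

# --- Illustrative sketch (not part of the original excerpt) ------------------
# The namespace returned by build_args() is meant to drive a PyTorch Lightning
# training/testing run. The cli_main/run_cli functions below sketch how that
# usually looks in the fastMRI VarNet example scripts; treat it as a rough
# guide under the assumption of the fastmri package layout
# (create_mask_for_mask_type, VarNetDataTransform, FastMriDataModule,
# VarNetModule), not as the exact main routine of this file.
from fastmri.data.subsample import create_mask_for_mask_type
from fastmri.data.transforms import VarNetDataTransform


def cli_main(args):
    pl.seed_everything(args.seed)

    # ------------
    # data
    # ------------
    # masking function built from --mask_type / --center_fractions /
    # --accelerations defined in build_args()
    mask = create_mask_for_mask_type(
        args.mask_type, args.center_fractions, args.accelerations
    )
    # random masks for train transform, fixed masks for val/test transforms
    train_transform = VarNetDataTransform(mask_func=mask, use_seed=False)
    val_transform = VarNetDataTransform(mask_func=mask)
    test_transform = VarNetDataTransform()

    # Lightning data module holding the train / val / test dataloaders
    data_module = FastMriDataModule(
        data_path=args.data_path,
        challenge=args.challenge,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        test_path=args.test_path,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        distributed_sampler=(args.accelerator in ("ddp", "ddp_cpu")),
    )

    # ------------
    # model
    # ------------
    model = VarNetModule(
        num_cascades=args.num_cascades,
        pools=args.pools,
        chans=args.chans,
        sens_pools=args.sens_pools,
        sens_chans=args.sens_chans,
        lr=args.lr,
        lr_step_size=args.lr_step_size,
        lr_gamma=args.lr_gamma,
        weight_decay=args.weight_decay,
    )

    # ------------
    # run
    # ------------
    trainer = pl.Trainer.from_argparse_args(args)

    if args.mode == "train":
        trainer.fit(model, datamodule=data_module)
    elif args.mode == "test":
        trainer.test(model, datamodule=data_module)
    else:
        raise ValueError(f"unrecognized mode {args.mode}")


def run_cli():
    args = build_args()
    cli_main(args)


if __name__ == "__main__":
    run_cli()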