in fastmri_examples/unet/unet_reproduce_20201111.py/unet_brain_leaderboard.py [0:0]
def build_args():
    """Build and parse the command-line arguments for the brain leaderboard U-Net run.

    Directory defaults (data path and log path) are looked up through
    ``fetch_dir`` using ``../../fastmri_dirs.yaml``. Client, data-module,
    model and trainer arguments are all registered on a single parser, and
    script-specific defaults are applied on top of each group.

    Side effects:
        Creates ``<default_root_dir>/checkpoints`` if it does not exist yet.

    Returns:
        The parsed argument namespace, with two additions:

        * ``checkpoint_callback`` is set to a ``ModelCheckpoint`` pointing at
          the checkpoint directory.
        * ``resume_from_checkpoint``, when not given, is set to the most
          recently modified ``*.ckpt`` file in the checkpoint directory (if
          any exists).
    """
    parser = ArgumentParser()

    # basic args
    path_config = pathlib.Path("../../fastmri_dirs.yaml")
    backend = "ddp"
    num_gpus = 32
    batch_size = 1

    # set defaults based on optional directory config
    data_path = fetch_dir("brain_path", path_config)
    default_root_dir = fetch_dir("log_path", path_config) / "unet" / "brain_leaderboard"

    # client arguments
    parser.add_argument(
        "--mode",
        default="train",
        choices=("train", "test"),
        type=str,
        help="Operation mode",
    )

    # data transform params
    parser.add_argument(
        "--mask_type",
        choices=("random", "equispaced"),
        default="equispaced",
        type=str,
        help="Type of k-space mask",
    )
    parser.add_argument(
        "--center_fractions",
        nargs="+",
        default=[0.08, 0.04],
        type=float,
        help="Number of center lines to use in mask",
    )
    parser.add_argument(
        "--accelerations",
        nargs="+",
        default=[4, 8],
        type=int,
        help="Acceleration rates to use for masks",
    )

    # data config with path to fastMRI data and batch size
    parser = FastMriDataModule.add_data_specific_args(parser)
    parser.set_defaults(
        data_path=data_path,  # path to fastMRI data
        mask_type="equispaced",  # equispaced for brain data
        challenge="multicoil",  # which challenge
        batch_size=batch_size,  # number of samples per batch
        test_path=None,  # path for test split, overwrites data_path
    )

    # module config
    parser = UnetModule.add_model_specific_args(parser)
    parser.set_defaults(
        in_chans=1,  # number of input channels to U-Net
        out_chans=1,  # number of output channels to U-Net
        chans=256,  # number of top-level U-Net channels
        num_pool_layers=4,  # number of U-Net pooling layers
        drop_prob=0.0,  # dropout probability
        lr=0.001,  # RMSProp learning rate
        lr_step_size=40,  # epoch at which to decrease learning rate
        lr_gamma=0.1,  # extent to which to decrease learning rate
        weight_decay=0.0,  # weight decay regularization strength
    )

    # trainer config
    parser = pl.Trainer.add_argparse_args(parser)
    parser.set_defaults(
        gpus=num_gpus,  # number of gpus to use
        replace_sampler_ddp=False,  # this is necessary for volume dispatch during val
        accelerator=backend,  # what distributed version to use
        seed=42,  # random seed
        deterministic=True,  # makes things slower, but deterministic
        default_root_dir=default_root_dir,  # directory for logs and checkpoints
        max_epochs=50,  # max number of epochs
    )

    args = parser.parse_args()

    # configure checkpointing in checkpoint_dir
    # The default is a pathlib.Path, but a value supplied on the command line
    # may arrive as a plain string; normalise so the "/" join works either way.
    checkpoint_dir = pathlib.Path(args.default_root_dir) / "checkpoints"
    # exist_ok avoids a FileExistsError when several processes (e.g. DDP
    # workers) reach this point at the same time.
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    args.checkpoint_callback = pl.callbacks.ModelCheckpoint(
        filepath=checkpoint_dir,
        verbose=True,
        prefix="",
    )

    # set default checkpoint if one exists in our checkpoint directory
    if args.resume_from_checkpoint is None:
        # oldest first, so the last entry is the most recently modified file
        ckpt_list = sorted(checkpoint_dir.glob("*.ckpt"), key=os.path.getmtime)
        if ckpt_list:
            args.resume_from_checkpoint = str(ckpt_list[-1])

    return args