in get_training_params.py [0:0]
def gen_args_dict(args):
    """Build a dict of training-parameter overrides from parsed arguments.

    Only options whose value on ``args`` is not ``None`` are copied into the
    result. Presumably the caller layers this dict over default training
    parameters (assumption -- confirm against callers).

    Args:
        args: An object exposing the attributes ``epochs``, ``test_freq``,
            ``learning_rate``, ``batch_size``, ``momentum``, ``weight_decay``,
            ``save_dir``, ``dataset`` and ``imagenet_dir`` (e.g. an
            ``argparse.Namespace``).

    Returns:
        dict: Maps each provided option name to its value. When
        ``args.dataset == "imagenet"`` it also maps ``"dataset_dir"`` to
        ``args.imagenet_dir``.

    Raises:
        ValueError: If ``args.dataset`` is ``"imagenet"`` but
            ``args.imagenet_dir`` is ``None``.
    """
    # Option names double as both the attribute read from ``args`` and the
    # key written to the result, so each option is always copied from its
    # own attribute (the previous hand-written version copied ``epochs``
    # into ``test_freq``).
    option_names = (
        "epochs",
        "test_freq",
        "learning_rate",
        "batch_size",
        "momentum",
        "weight_decay",
        "save_dir",
    )
    a_dict = {}
    for name in option_names:
        value = getattr(args, name)
        if value is not None:
            a_dict[name] = value

    # ImageNet has no default location here, so its directory is mandatory.
    if args.dataset == "imagenet":
        if args.imagenet_dir is None:
            raise ValueError(
                "ImageNet data directory must be specified as --imagenet_dir <dir>"
            )
        a_dict["dataset_dir"] = args.imagenet_dir
    return a_dict