in args.py [0:0]
def gen_args(desc):
    """Build the command-line argument parser for training runs.

    Args:
        desc: Description string shown at the top of the parser's
            ``--help`` output.

    Returns:
        An ``argparse.ArgumentParser`` with model, dataset, path and
        optimizer options registered. Only ``--model`` (default
        ``"cpreresnet20"``) and ``--dataset`` (default ``"cifar10"``) have
        defaults; every other option parses to ``None`` unless supplied.
    """
    # Function-scope import: no module-level import of argparse is visible
    # in this file, and re-importing an already-loaded module is cheap.
    import argparse

    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument(
        "--model",
        type=str,
        default="cpreresnet20",
        # Help text previously misspelled the default as "crepresnet20".
        help="Which model architecture to use. One of cpreresnet20, resnet18, vgg19",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="cifar10",
        help="Which dataset to use. One of cifar10, imagenet",
    )
    parser.add_argument(
        "--imagenet_dir",
        type=str,
        help="The root directory to the ImageNet dataset",
    )
    parser.add_argument("--save_dir", type=str, help="Directory to save model")
    parser.add_argument(
        "--norm",
        type=str,
        help="Layer normalization type. One of BN, IN, GN",
    )
    parser.add_argument("--epochs", type=int, help="Number of training epochs")
    parser.add_argument(
        "--test_freq", type=int, help="Number of epochs between testing"
    )
    parser.add_argument(
        "--learning_rate", type=float, help="Optimizer learning rate"
    )
    parser.add_argument(
        "--batch_size", type=int, help="Training/test batch size"
    )
    parser.add_argument("--momentum", type=float, help="Optimizer momentum")
    parser.add_argument(
        "--weight_decay", type=float, help="L2 regularization parameter"
    )
    return parser