def gen_args()

in args.py [0:0]


def gen_args(desc):
    """Build the command-line argument parser for training.

    Args:
        desc: Description text shown at the top of ``--help`` output.

    Returns:
        An un-parsed ``argparse.ArgumentParser``; callers invoke
        ``parse_args()`` themselves.

    Only ``--model`` and ``--dataset`` have explicit defaults. Every other
    option is ``None`` when omitted, so callers must supply or fill in those
    values.
    """
    parser = argparse.ArgumentParser(description=desc)

    parser.add_argument(
        "--model",
        type=str,
        default="cpreresnet20",
        help="Which model architecture to use. One of cpreresnet20, resnet18, vgg19",
    )

    parser.add_argument(
        "--dataset",
        type=str,
        default="cifar10",
        help="Which dataset to use. One of cifar10, imagenet",
    )

    # Only meaningful when --dataset is imagenet; no default is provided.
    parser.add_argument(
        "--imagenet_dir",
        type=str,
        help="The root directory to the ImageNet dataset",
    )

    parser.add_argument("--save_dir", type=str, help="Directory to save model")

    parser.add_argument(
        "--norm",
        type=str,
        help="Layer normalization type. One of BN, IN, GN",
    )

    # Training schedule and optimizer hyperparameters (all default to None).
    parser.add_argument("--epochs", type=int, help="Number of training epochs")

    parser.add_argument(
        "--test_freq", type=int, help="Number of epochs between testing"
    )

    parser.add_argument(
        "--learning_rate", type=float, help="Optimizer learning rate"
    )

    parser.add_argument(
        "--batch_size", type=int, help="Training/test batch size"
    )

    parser.add_argument("--momentum", type=float, help="Optimizer momentum")

    parser.add_argument(
        "--weight_decay", type=float, help="L2 regularization parameter"
    )

    return parser