in models/bmshj2018.py [0:0]
def parse_args(argv):
  """Parses command line arguments.

  Builds a parser with global options (`--verbose`, `--model_path`) and three
  subcommands: `train`, `compress` and `decompress`.

  Args:
    argv: Command line as a list of strings. The first element is treated as
      the program name and skipped, as with `sys.argv`. If `None`, the parser
      falls back to reading `sys.argv[1:]` itself.

  Returns:
    An `argparse.Namespace` with the parsed options. `args.command` holds the
    name of the selected subcommand.

  Raises:
    SystemExit: On invalid arguments or when no subcommand is given (exit
      status 2), and after printing help for `-h`.
  """
  parser = argparse_flags.ArgumentParser(
      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  # High-level options.
  parser.add_argument(
      "--verbose", "-V", action="store_true",
      help="Report progress and metrics when training or compressing.")
  parser.add_argument(
      "--model_path", default="bmshj2018",
      help="Path where to save/load the trained model.")
  subparsers = parser.add_subparsers(
      title="commands", dest="command",
      help="What to do: 'train' loads training data and trains (or continues "
           "to train) a new model. 'compress' reads an image file (lossless "
           "PNG format) and writes a compressed binary file. 'decompress' "
           "reads a binary file and reconstructs the image (in PNG format). "
           "An input filename must be provided for the latter two options; "
           "the output filename is optional. Invoke '<command> -h' for more "
           "information.")

  # 'train' subcommand.
  train_cmd = subparsers.add_parser(
      "train",
      formatter_class=argparse.ArgumentDefaultsHelpFormatter,
      description="Trains (or continues to train) a new model. Note that this "
                  "model trains on a continuous stream of patches drawn from "
                  "the training image dataset. An epoch is always defined as "
                  "the same number of batches given by --steps_per_epoch. "
                  "The purpose of validation is mostly to evaluate the "
                  "rate-distortion performance of the model using actual "
                  "quantization rather than the differentiable proxy loss. "
                  "Note that when using custom training images, the validation "
                  "set is simply a random sampling of patches from the "
                  "training set.")
  # `lambda` is a Python keyword, so the value is stored as `args.lmbda`.
  train_cmd.add_argument(
      "--lambda", type=float, default=0.01, dest="lmbda",
      help="Lambda for rate-distortion tradeoff.")
  train_cmd.add_argument(
      "--train_glob", type=str, default=None,
      help="Glob pattern identifying custom training data. This pattern must "
           "expand to a list of RGB images in PNG format. If unspecified, the "
           "CLIC dataset from TensorFlow Datasets is used.")
  train_cmd.add_argument(
      "--num_filters", type=int, default=192,
      help="Number of filters per layer.")
  train_cmd.add_argument(
      "--num_scales", type=int, default=64,
      help="Number of Gaussian scales to prepare range coding tables for.")
  train_cmd.add_argument(
      "--scale_min", type=float, default=.11,
      help="Minimum value of standard deviation of Gaussians.")
  train_cmd.add_argument(
      "--scale_max", type=float, default=256.,
      help="Maximum value of standard deviation of Gaussians.")
  train_cmd.add_argument(
      "--train_path", default="/tmp/train_bmshj2018",
      help="Path where to log training metrics for TensorBoard and back up "
           "intermediate model checkpoints.")
  train_cmd.add_argument(
      "--batchsize", type=int, default=8,
      help="Batch size for training and validation.")
  train_cmd.add_argument(
      "--patchsize", type=int, default=256,
      help="Size of image patches for training and validation.")
  train_cmd.add_argument(
      "--epochs", type=int, default=1000,
      help="Train up to this number of epochs. (One epoch is here defined as "
           "the number of steps given by --steps_per_epoch, not iterations "
           "over the full training dataset.)")
  train_cmd.add_argument(
      "--steps_per_epoch", type=int, default=1000,
      help="Perform validation and produce logs after this many batches.")
  train_cmd.add_argument(
      "--max_validation_steps", type=int, default=16,
      help="Maximum number of batches to use for validation. If -1, use one "
           "patch from each image in the training set.")
  train_cmd.add_argument(
      "--preprocess_threads", type=int, default=16,
      help="Number of CPU threads to use for parallel decoding of training "
           "images.")
  train_cmd.add_argument(
      "--check_numerics", action="store_true",
      help="Enable TF support for catching NaN and Inf in tensors.")

  # 'compress' subcommand.
  compress_cmd = subparsers.add_parser(
      "compress",
      formatter_class=argparse.ArgumentDefaultsHelpFormatter,
      description="Reads a PNG file, compresses it, and writes a TFCI file.")

  # 'decompress' subcommand.
  decompress_cmd = subparsers.add_parser(
      "decompress",
      formatter_class=argparse.ArgumentDefaultsHelpFormatter,
      description="Reads a TFCI file, reconstructs the image, and writes back "
                  "a PNG file.")

  # Arguments for both 'compress' and 'decompress'. `ext` is only used in the
  # help text here; it documents the suffix appended when `output_file` is
  # omitted.
  for cmd, ext in ((compress_cmd, ".tfci"), (decompress_cmd, ".png")):
    cmd.add_argument(
        "input_file",
        help="Input filename.")
    cmd.add_argument(
        "output_file", nargs="?",
        help=f"Output filename (optional). If not provided, appends '{ext}' to "
             f"the input filename.")

  # Parse arguments. Skip argv[0] (the program name); with `None`, the parser
  # reads `sys.argv[1:]` itself instead of crashing on `None[1:]`.
  args = parser.parse_args(None if argv is None else argv[1:])
  # Subparsers are optional by default, so a bare invocation leaves
  # `command` unset. `parser.error` prints usage plus the reason to stderr
  # and exits with status 2.
  if args.command is None:
    parser.error(
        "a command is required (one of 'train', 'compress', 'decompress')")
  return args