in ludwig/api.py [0:0]
def main(sys_argv):
    """Run one of the Ludwig API smoke tests selected on the command line.

    Builds an argparse parser covering data, model, generic and runtime
    options, then parses ``sys_argv``. It maps the ``--logging_level`` name
    to its registered value via ``logging_level_registry``. Finally it calls
    the test function chosen with ``--test``, forwarding every parsed option
    (including ``test`` itself) as a keyword argument.

    Args:
        sys_argv: command-line argument strings handed to
            ``ArgumentParser.parse_args``.
    """
    arg_parser = argparse.ArgumentParser(
        description='This script tests ludwig APIs.'
    )
    add = arg_parser.add_argument

    add('-t', '--test', default='train',
        choices=['train', 'train_online', 'predict'],
        help='which test to run')

    # Data parameters
    add('--data_csv', help='input data CSV file')
    add('--train_set_metadata_json', help='input metadata JSON file')

    # Model parameters
    add('-m', '--model_path', help='model to load')
    # The raw argument string is parsed with yaml.safe_load, so the model
    # definition can be supplied inline as YAML text.
    add('-md', '--model_definition', type=yaml.safe_load,
        help='model definition')

    # Generic parameters
    add('-bs', '--batch_size', type=int, default=128,
        help='size of batches')

    # Runtime parameters
    # NOTE(review): help text says "list" but type=int parses a single
    # integer -- confirm which the test functions expect.
    add('-g', '--gpus', type=int, default=None,
        help='list of gpu to use')
    add('-gml', '--gpu_memory_limit', type=int, default=None,
        help='maximum memory in MB to allocate per GPU device')
    # store_false: passing the flag sets allow_parallel_threads to False;
    # otherwise it stays True.
    add('-dpt', '--disable_parallel_threads', action='store_false',
        dest='allow_parallel_threads',
        help='disable TensorFlow from using multithreading for reproducibility')
    add('-dbg', '--debug', action='store_true', default=False,
        help='enables debugging mode')
    add('-l', '--logging_level', default='info',
        choices=['critical', 'error', 'warning', 'info', 'debug', 'notset'],
        help='the level of logging to use')

    parsed = arg_parser.parse_args(sys_argv)
    # Replace the level name with the value registered for it.
    parsed.logging_level = logging_level_registry[parsed.logging_level]

    runners = {
        'train': test_train,
        'train_online': test_train_online,
        'predict': test_predict,
    }
    run_test = runners.get(parsed.test)
    if run_test is None:
        # Defensive only: the `choices` on --test already rejects anything
        # not in the table above.
        logger.info('Unsupported test type')
        return
    run_test(**vars(parsed))