def cli_collect_weights()

in ludwig/collect.py [0:0]


def cli_collect_weights(sys_argv):
    """Command line entry point for collecting the weights of a model.

    Parses ``sys_argv``, configures the 'ludwig' logger level, prints the
    Ludwig banner and forwards every parsed option to ``collect_weights``
    as keyword arguments.

    Recognized options:
      -m, --model_path: model to load, this is a required option
      -t, --tensors: one or more tensors to collect, this is a required
          option
      -od, --output_directory: directory that contains the results,
          defaults to 'results'
      -dbg, --debug: flag that enables debugging mode, off by default
      -l, --logging_level: the level of logging to use, one of critical,
          error, warning, info, debug or notset, defaults to 'info'
    """
    # The module level logger is rebound below, once the level is set.
    global logger

    arg_parser = argparse.ArgumentParser(
        prog='ludwig collect_weights',
        usage='%(prog)s [options]',
        description='This script loads a pretrained model '
                    'and uses it collect weights.'
    )

    # ----------------
    # Model parameters
    # ----------------
    arg_parser.add_argument(
        '-m', '--model_path', required=True, help='model to load'
    )
    arg_parser.add_argument(
        '-t', '--tensors', nargs='+', required=True,
        help='tensors to collect'
    )

    # -------------------------
    # Output results parameters
    # -------------------------
    arg_parser.add_argument(
        '-od', '--output_directory', type=str, default='results',
        help='directory that contains the results'
    )

    # ------------------
    # Runtime parameters
    # ------------------
    arg_parser.add_argument(
        '-dbg', '--debug', action='store_true', default=False,
        help='enables debugging mode'
    )
    arg_parser.add_argument(
        '-l', '--logging_level', default='info',
        choices=['critical', 'error', 'warning', 'info', 'debug', 'notset'],
        help='the level of logging to use'
    )

    parsed = arg_parser.parse_args(sys_argv)

    # Apply the requested level to the package logger before anything
    # else gets a chance to log.
    package_logger = logging.getLogger('ludwig')
    package_logger.setLevel(logging_level_registry[parsed.logging_level])
    logger = logging.getLogger('ludwig.collect')

    print_ludwig('Collect Weights', LUDWIG_VERSION)

    # Every parsed option, logging_level included, is passed through.
    collect_weights(**vars(parsed))