nubia/internal/typing/argparse.py (227 lines of code) (raw):

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#

import argparse
import copy
import os
import shlex
import shutil
import subprocess
import sys
from collections import defaultdict
from functools import partial
from typing import Any, Dict, List, Tuple

from nubia.internal.helpers import try_await  # noqa F401
from nubia.internal.typing.builder import (
    build_value,
    get_dict_kv_arg_type_as_str,
    get_list_arg_type_as_str,
)
from nubia.internal.typing.inspect import (
    get_first_type_argument,
    is_iterable_type,
    is_mapping_type,
    is_optional_type,
)

from . import command, inspect_object, transform_name


def create_subparser_class(opts_parser):
    """Build an ``ArgumentParser`` subclass that inherits *opts_parser*'s options.

    This is a hack to add the main parser arguments to each subcommand in
    order to allow main parser arguments to be specified after the
    subcommand, e.g.

        my_prog status -t <tier> --atonce=10

    The rationale of the implementation chosen is to propagate mutually
    exclusive groups from main parser to subparsers. While it is possible
    to infer kwargs from main parser actions list then passing them to the
    add_argument() method for each subparser, it will make us lose any
    information about mutually exclusive groups.

    The returned class is intended to be passed as ``parser_class`` to
    ``add_subparsers``.
    """

    class SubParser(argparse.ArgumentParser):
        def __init__(self, *args, **kwargs):
            # No automatic -h/--help; presumably the copied opts_parser
            # actions already provide one (TODO confirm).
            kwargs["add_help"] = False
            super(SubParser, self).__init__(*args, **kwargs)
            # option_strings fingerprints of actions already copied here
            self._copied_actions_fingerprints = set()
            # Copy mutually exclusive groups first, so their actions land in
            # the group copies rather than in the plain optionals group.
            self._copy_mutually_exclusive_groups()
            # Obviously we care only about optionals
            self._copy_optionals()

        def _copy_action(self, action, group, default=argparse.SUPPRESS):
            """Shallow-copy *action* into *group* unless already copied.

            The copy's default is replaced by *default* (SUPPRESS by
            default) so the sub-namespace never overwrites values parsed
            by the main parser.
            """
            action_fingerprint = "".join(action.option_strings)
            # Avoid adding same option twice
            if action_fingerprint not in self._copied_actions_fingerprints:
                # FIXME: this is a really, really bad idea
                a = copy.copy(action)
                # Avoid common arguments to be overridden by subnamespace
                a.default = default
                group._add_action(a)
                self._copied_actions_fingerprints.add(action_fingerprint)

        def _copy_mutually_exclusive_groups(self):
            """Recreate each of opts_parser's mutually exclusive groups."""
            for mutex_group in opts_parser._mutually_exclusive_groups:
                mutex_group_copy = self.add_mutually_exclusive_group(
                    required=mutex_group.required
                )
                for action in mutex_group._group_actions:
                    self._copy_action(action, mutex_group_copy)

        def _copy_optionals(self):
            """Copy opts_parser's remaining actions into the optionals group."""
            # NOTE(review): in CPython's argparse a group's _actions list is
            # shared with its parser, so this likely walks every action of
            # opts_parser, not only optionals — confirm.
            for action in opts_parser._optionals._actions:
                # Skip _SubParsersAction from main parser
                if not isinstance(action, argparse._SubParsersAction):
                    self._copy_action(action, self._optionals)

    return SubParser


def add_command(argparse_parser, function):
    """Register *function* as a command on *argparse_parser*.

    *argparse_parser* may be an ``ArgumentParser`` or a
    ``_SubParsersAction``. Functions not already decorated with ``@command``
    are wrapped with it first. Returns the subparser created for the command.
    """
    inspection = inspect_object(function)
    if not inspection.command:
        # auto wrap the function with @command in case its not wrapped into one
        return add_command(argparse_parser, command(function))

    parser = register_command(argparse_parser, inspection)
    # put a back reference so we can find this function later on `find_command`
    # used in testing
    parser.__command = function
    return parser


def register_command(argparse_parser, inspection):
    """Create the subparser (and its arguments) for an inspected command.

    Arguments listed in the command's ``exclusive_arguments`` are placed in
    argparse mutually exclusive groups. Super-commands get a nested, required
    ``[subcommand]`` level and their subcommands are registered recursively.

    Returns the created subparser.

    Raises:
        ValueError: if an argument belongs to more than one exclusive group.
    """
    _command = inspection.command
    subparsers = _resolve_subparsers(argparse_parser)
    subparser = subparsers.add_parser(
        _command.name, aliases=_command.aliases, help=_command.help
    )

    # Exclusive arguments needs to be added to argparse's mutually exclusive
    # groups; one argparse group is created lazily per declared group.
    exclusive_args = _command.exclusive_arguments or []
    mutually_exclusive_groups = defaultdict(subparser.add_mutually_exclusive_group)

    for arg in inspection.arguments.values():
        add_argument_args, add_argument_kwargs = _argument_to_argparse_input(arg)
        groups = [group for group in exclusive_args if arg.name in group]
        if not groups:
            subparser.add_argument(*add_argument_args, **add_argument_kwargs)
        elif len(groups) == 1:
            me_group = mutually_exclusive_groups[groups[0]]
            me_group.add_argument(*add_argument_args, **add_argument_kwargs)
        else:
            msg = (
                "Argument {} is present in more than one exclusive "
                "group: {}. This should not be allowed by the @command "
                "decorator".format(arg.name, groups)
            )
            raise ValueError(msg)

    # if we are adding a super command then we need to create a sub parser for
    # this
    if len(inspection.subcommands) > 0:
        subcommand_parsers = subparser.add_subparsers(
            dest="_subcmd",
            help=_command.help,
            parser_class=create_subparser_class(subparser),
            metavar="[subcommand]",
        )
        subcommand_parsers.required = True

        # recursively add sub-commands
        for _, v in inspection.subcommands:
            register_command(subcommand_parsers, v)

    return subparser


def _resolve_subparsers(parser):
    """Return the ``_SubParsersAction`` that commands should be added to.

    A ``_SubParsersAction`` is returned as-is. For an ``ArgumentParser`` the
    existing subparsers action is reused, or one is created with
    ``dest="_cmd"`` if the parser has none yet.

    Raises:
        ValueError: if *parser* is neither of those types.
    """
    # a subparser resulting from parser.add_subparsers was inputted
    if isinstance(parser, argparse._SubParsersAction):
        return parser

    # an actual parser was inputted
    if isinstance(parser, argparse.ArgumentParser):
        # Unfortunately there is no method to get the current subparsers apart
        # from reading private properties. Trying to call
        # parser.add_subparsers a second time will result in a SystemExit
        # error.
        # Argparse is a beautiful thing
        if getattr(parser, "_subparsers", None):
            # parser._subparsers._actions is the parser's whole action list,
            # so its last entry is only the subparsers action if nothing was
            # added afterwards. Look the action up by type instead.
            existing = next(
                (
                    action
                    for action in reversed(parser._subparsers._actions)
                    if isinstance(action, argparse._SubParsersAction)
                ),
                None,
            )
            if existing is not None:
                return existing
        return parser.add_subparsers(dest="_cmd", help="Subcommand to run")

    raise ValueError(
        "Expected an argparse.ArgumentParser or an "
        "argparse._SubParsersAction as input"
    )


def _argument_to_argparse_input(arg: "Any") -> "Tuple[List, Dict[str, Any]]":
    """Translate a nubia argument description into ``add_argument`` input.

    Returns an ``(args, kwargs)`` pair meant to be splatted into
    ``ArgumentParser.add_argument``.

    Raises:
        ValueError: for positional arguments that declare aliases
            (``extra_names``) or a default value.
    """
    add_argument_kwargs = {"help": arg.description}
    if arg.positional:
        add_argument_args = [arg.name]
        if arg.extra_names:
            msg = "Aliases are not yet supported for positional arguments @ {}".format(
                arg.name
            )
            raise ValueError(msg)
        if arg.default_value_set:
            msg = (
                "Positional arguments with default values are "
                "not supported @ {}".format(arg.name)
            )
            raise ValueError(msg)
    else:
        add_argument_args = [
            transform_argument_name(x) for x in ([arg.name] + arg.extra_names)
        ]
        add_argument_kwargs["default"] = arg.default_value
        # an option without a declared default must be given on the CLI
        add_argument_kwargs["required"] = not arg.default_value_set

    # Optional[T] is handled as T itself
    argument_type = (
        arg.type if not is_optional_type(arg.type) else get_first_type_argument(arg.type)
    )
    if argument_type in [int, float, str]:
        add_argument_kwargs["type"] = argument_type
        add_argument_kwargs["metavar"] = str(argument_type.__name__).upper()
    elif argument_type == bool or arg.default_value is False:
        # NOTE(review): a bool argument defaulting to True also lands here
        # (store_true) rather than in the store_false branch below, so its
        # flag cannot switch it off — confirm this is intended.
        add_argument_kwargs["action"] = "store_true"
    elif arg.default_value is True:
        add_argument_kwargs["action"] = "store_false"
    elif is_mapping_type(argument_type):
        add_argument_kwargs["type"] = _parse_dict(argument_type)
        add_argument_kwargs["metavar"] = "DICT[{}: {}]".format(
            *get_dict_kv_arg_type_as_str(argument_type)
        )
    elif is_iterable_type(argument_type):
        # each CLI token is converted to the element type; one or more needed
        add_argument_kwargs["type"] = get_first_type_argument(argument_type)
        add_argument_kwargs["nargs"] = "+"
        add_argument_kwargs["metavar"] = "{}".format(
            get_list_arg_type_as_str(argument_type)
        )
    else:
        add_argument_kwargs["type"] = argument_type

    if arg.choices:
        add_argument_kwargs["choices"] = arg.choices
        add_argument_kwargs["metavar"] = "{{{}}}".format(
            ",".join(map(str, arg.choices))
        )

    # positionals show their name next to the type hint, e.g. "count<INT>"
    if arg.positional and "metavar" in add_argument_kwargs:
        add_argument_kwargs["metavar"] = "{}<{}>".format(
            arg.name, add_argument_kwargs["metavar"]
        )

    return add_argument_args, add_argument_kwargs


def find_command(parser, parsed_args, curry_args=False):
    """Return the function registered via ``add_command`` for *parsed_args*.

    The command is looked up by ``parsed_args._cmd`` among *parser*'s
    top-level subcommands. Returns None when no command matched or when the
    matched parser has no function attached (e.g. it was created with
    ``register_command`` directly). With *curry_args*, the parsed arguments
    are bound to the function with ``functools.partial``.
    """
    subparsers = _resolve_subparsers(parser)
    command_parser = subparsers._name_parser_map.get(parsed_args._cmd)
    # getattr instead of attribute access: parsers without the back
    # reference set by add_command mean "not found", not AttributeError.
    function = getattr(command_parser, "__command", None)
    if not function:
        return None

    if curry_args:
        kwargs = get_arguments_for_command(function, parsed_args)
        function = partial(function, **kwargs)

    return function


def get_arguments_for_inspection(inspection, kwargs):
    """Map parsed argument names back to the command's parameter names.

    Names and aliases are normalized with ``transform_name(..., to_char="_")``
    and translated to the underlying parameter (``arg_obj.arg``). Entries
    that do not correspond to a parameter of the command (such as ``_cmd``,
    which only identifies the parser) are dropped.
    """
    # map back from names or extra names given to arguments to the actual
    # arguments taken by the function; aliases are applied second so they
    # take precedence over primary names on collision.
    names_to_args = {
        transform_name(arg_obj.name, to_char="_"): arg_obj.arg
        for arg_obj in inspection.arguments.values()
    }
    names_to_args.update(
        {
            transform_name(extra_name, to_char="_"): arg_obj.arg
            for arg_obj in inspection.arguments.values()
            for extra_name in arg_obj.extra_names
        }
    )

    valid_args = {arg_obj.arg for arg_obj in inspection.arguments.values()}

    # use the reverse map to convert the names used in parsing to the actual
    # arguments used in the command function, keeping only arguments this
    # function accepts.
    resolved = ((names_to_args.get(name, name), value) for name, value in kwargs.items())
    return {arg: value for arg, value in resolved if arg in valid_args}


def get_arguments_for_command(function, parsed_args):
    """Return the keyword arguments for *function* taken from *parsed_args*."""
    inspection = inspect_object(function)
    kwargs = dict(parsed_args._get_kwargs())
    return get_arguments_for_inspection(inspection, kwargs)


def transform_argument_name(name):
    """
    Similar to transform_name, this is specific to export argument names for
    the cli mode. Single character friendly names are treated as flags and
    have a single dash (-) instead of a double dash (--)

    For instance:
        __special__ => --special
        _some_arg => --some-arg
        _f => -f
    """
    name = transform_name(name)
    return "--{}".format(name) if len(name) > 1 else "-{}".format(name)


def _parse_dict(target_type):
    """Return an argparse ``type=`` callable that builds a *target_type* value."""

    def parse_dict_(value):
        return build_value(value, target_type, python_syntax=False)

    return parse_dict_


def _pager_command():
    """Return the pager command line from ``$PAGER`` as an argv list.

    ``$PAGER`` may carry arguments (e.g. ``less -R``), so it is split the way
    a shell would. Unset, empty, blank or unparsable values fall back to
    ``["less"]``.
    """
    try:
        return shlex.split(os.environ.get("PAGER", "")) or ["less"]
    except ValueError:
        # e.g. unbalanced quotes in $PAGER
        return ["less"]


class NubiaHelpAction(argparse.Action):
    """An action that pipes help message to the pager."""

    def __init__(
        self,
        option_strings,
        dest=argparse.SUPPRESS,
        default=argparse.SUPPRESS,
        help=None,
    ):
        super(NubiaHelpAction, self).__init__(
            option_strings=option_strings,
            dest=dest,
            default=default,
            nargs=0,
            help=help,
        )

    def __call__(self, parser, namespace, values, option_string=None):
        """Show *parser*'s help (paged when it overflows a TTY) and exit."""
        help_message = parser.format_help()
        help_message_length = len(help_message.split("\n"))
        _, rows = shutil.get_terminal_size()
        fits_one_page = help_message_length <= rows
        if sys.stdout.isatty() and not fits_one_page:
            try:
                subprocess.run(_pager_command(), input=help_message.encode())
            except OSError:
                # pager missing or not executable: print directly instead
                parser.print_help()
        else:
            # fallback
            parser.print_help()
        parser.exit()