nubia/internal/typing/argparse.py (227 lines of code) (raw):
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
import argparse
import copy
import os
import shutil
import subprocess
import sys
from collections import defaultdict
from functools import partial
from typing import Any, Dict, List, Tuple
from nubia.internal.helpers import try_await # noqa F401
from nubia.internal.typing.builder import (
build_value,
get_dict_kv_arg_type_as_str,
get_list_arg_type_as_str,
)
from nubia.internal.typing.inspect import (
get_first_type_argument,
is_iterable_type,
is_mapping_type,
is_optional_type,
)
from . import command, inspect_object, transform_name
def create_subparser_class(opts_parser):
    """Build an ``ArgumentParser`` subclass that inherits *opts_parser*'s optionals.

    This is a hack to add the main parser arguments to each subcommand in
    order to allow main parser arguments to be specified after the
    subcommand, e.g.

        my_prog status -t <tier> --atonce=10

    The rationale of the implementation chosen is to propagate mutually
    exclusive groups from main parser to subparsers. While it is possible
    to infer kwargs from main parser actions list then passing them to the
    add_argument() method for each subparser, it will make us lose any
    information about mutually exclusive groups.

    Args:
        opts_parser: the parser whose optional arguments and mutually
            exclusive groups are copied into every instance of the
            returned class.

    Returns:
        A subclass of ``argparse.ArgumentParser`` meant to be passed as the
        ``parser_class`` of ``add_subparsers``.
    """

    class SubParser(argparse.ArgumentParser):
        def __init__(self, *args, **kwargs):
            # The -h/--help action of opts_parser (if any) gets copied below,
            # so the default help action must not be added as well.
            kwargs["add_help"] = False
            super(SubParser, self).__init__(*args, **kwargs)
            self._copied_actions_fingerprints = set()
            # Copy mutually exclusive groups first, so their members land in
            # a group instead of being copied as plain optionals.
            self._copy_mutually_exclusive_groups()
            self._copy_optionals()

        def _copy_action(self, action, group, default=argparse.SUPPRESS):
            # A tuple cannot collide the way a joined string can
            # (e.g. ["-a", "-b"] and ["-a-b"] both join to "-a-b").
            action_fingerprint = tuple(action.option_strings)
            # Avoid adding same option twice
            if action_fingerprint in self._copied_actions_fingerprints:
                return
            # FIXME: this is a really, really bad idea -- the shallow copy
            # shares every mutable attribute with the original action.
            copied = copy.copy(action)
            # Avoid common arguments to be overridden by subnamespace: an
            # option not given after the subcommand must not clobber the
            # value the parent parser already parsed.
            copied.default = default
            group._add_action(copied)
            self._copied_actions_fingerprints.add(action_fingerprint)

        def _copy_mutually_exclusive_groups(self):
            for mutex_group in opts_parser._mutually_exclusive_groups:
                mutex_group_copy = self.add_mutually_exclusive_group(
                    required=mutex_group.required
                )
                for action in mutex_group._group_actions:
                    self._copy_action(action, mutex_group_copy)

        def _copy_optionals(self):
            for action in opts_parser._optionals._actions:
                # ``_optionals._actions`` is argparse's parser-wide action
                # list: it also contains positionals and the
                # _SubParsersAction. Only real optionals (with option
                # strings) are propagated; copying a positional would make
                # every subcommand demand it a second time.
                if not action.option_strings:
                    continue
                if isinstance(action, argparse._SubParsersAction):
                    continue
                self._copy_action(action, self._optionals)

    return SubParser
def add_command(argparse_parser, function):
    """Register *function* as a command on *argparse_parser*.

    A function that was not decorated with ``@command`` is wrapped with it
    on the fly and then registered.

    Returns:
        The parser created for the command. It carries a ``__command``
        back-reference to the registered function so ``find_command``
        (used in testing) can map a parsed command back to it.
    """
    inspection = inspect_object(function)
    if inspection.command:
        command_parser = register_command(argparse_parser, inspection)
        # put a back reference so we can find this function later on
        # `find_command`, used in testing
        command_parser.__command = function
        return command_parser
    # auto wrap the function with @command in case it's not wrapped into one
    return add_command(argparse_parser, command(function))
def register_command(argparse_parser, inspection):
    """Add the command described by *inspection* as a subparser.

    Super commands (those with subcommands) get their own subparsers
    action, and their subcommands are registered recursively.

    Args:
        argparse_parser: an ``argparse.ArgumentParser`` or an
            ``argparse._SubParsersAction`` to attach the command to.
        inspection: the result of ``inspect_object`` for the command.

    Returns:
        The ``ArgumentParser`` created for the command.

    Raises:
        ValueError: if an argument appears in more than one exclusive group.
    """
    _command = inspection.command
    subparsers = _resolve_subparsers(argparse_parser)
    subparser = subparsers.add_parser(
        _command.name, aliases=_command.aliases, help=_command.help
    )

    # Exclusive arguments need to be added to argparse's mutually exclusive
    # groups. One argparse group is created lazily per declared group; the
    # declared groups are dict keys here, so they must be hashable
    # (presumably tuples -- TODO confirm against @command).
    exclusive_args = _command.exclusive_arguments or []
    mutually_exclusive_groups = defaultdict(subparser.add_mutually_exclusive_group)
    for arg in inspection.arguments.values():
        add_argument_args, add_argument_kwargs = _argument_to_argparse_input(arg)
        groups = [group for group in exclusive_args if arg.name in group]
        if not groups:
            target = subparser
        elif len(groups) == 1:
            target = mutually_exclusive_groups[groups[0]]
        else:
            msg = (
                "Argument {} is present in more than one exclusive "
                "group: {}. This should not be allowed by the @command "
                "decorator".format(arg.name, groups)
            )
            raise ValueError(msg)
        target.add_argument(*add_argument_args, **add_argument_kwargs)

    # if we are adding a super command then we need to create a sub parser for
    # this
    if inspection.subcommands:
        subcommand_parsers = subparser.add_subparsers(
            dest="_subcmd",
            help=_command.help,
            parser_class=create_subparser_class(subparser),
            metavar="[subcommand]",
        )
        subcommand_parsers.required = True
        # recursively add sub-commands
        for _, subcommand_inspection in inspection.subcommands:
            register_command(subcommand_parsers, subcommand_inspection)
    return subparser
def _resolve_subparsers(parser):
# a subparser resulting from parser.add_subparsers was inputted
if isinstance(parser, argparse._SubParsersAction):
subparsers = parser
# an actual parser was inputted
elif isinstance(parser, argparse.ArgumentParser):
# Unfortunately there is no method to get the current subparsers apart
# from reading the private property. Trying to call
# parser.add_subparsers a second time will result in a SystemExit error.
# Also when you call parser.add_subparsers you get an Action object,
# that is listed under parser._subparsers._actions.
# Argparse is a beautiful thing
if getattr(parser, "_subparsers", None):
subparsers = parser._subparsers._actions[-1]
else:
subparsers = parser.add_subparsers(dest="_cmd", help="Subcommand to run")
else:
raise ValueError(
"Expected an argparse.ArgumentParser or an "
"argparse._SubParsersAction as input"
)
return subparsers
def _argument_to_argparse_input(arg: "Any") -> "Tuple[List, Dict[str, Any]]":
    """Translate a nubia argument into ``parser.add_argument`` inputs.

    Args:
        arg: an argument object as found in
            ``inspect_object(...).arguments``. The attributes read here are
            name, extra_names, type, default_value, default_value_set,
            positional, choices and description.

    Returns:
        A ``(names, kwargs)`` tuple meant to be used as
        ``parser.add_argument(*names, **kwargs)``.

    Raises:
        ValueError: if a positional argument declares aliases or a default
            value; neither is supported for positionals.
    """
    kwargs = {"help": arg.description}

    if not arg.positional:
        # Optional arguments: the name and every alias become flags.
        names = [transform_argument_name(n) for n in [arg.name] + arg.extra_names]
        kwargs["default"] = arg.default_value
        kwargs["required"] = not arg.default_value_set
    else:
        names = [arg.name]
        if arg.extra_names:
            raise ValueError(
                "Aliases are not yet supported for positional arguments "
                "@ {}".format(arg.name)
            )
        if arg.default_value_set:
            raise ValueError(
                "Positional arguments with default values are "
                "not supported @ {}".format(arg.name)
            )

    # For optional types, work with the wrapped type argument instead.
    if is_optional_type(arg.type):
        argument_type = get_first_type_argument(arg.type)
    else:
        argument_type = arg.type

    if argument_type in (int, float, str):
        kwargs["type"] = argument_type
        kwargs["metavar"] = argument_type.__name__.upper()
    elif argument_type == bool or arg.default_value is False:
        # NOTE(review): this branch is checked before the ``store_false``
        # one, so a ``bool`` argument defaulting to True also becomes
        # ``store_true`` and cannot be switched off from the command line;
        # ``store_false`` is only reached for non-bool types. Confirm this
        # is intended.
        kwargs["action"] = "store_true"
    elif arg.default_value is True:
        kwargs["action"] = "store_false"
    elif is_mapping_type(argument_type):
        kwargs["type"] = _parse_dict(argument_type)
        kwargs["metavar"] = "DICT[{}: {}]".format(
            *get_dict_kv_arg_type_as_str(argument_type)
        )
    elif is_iterable_type(argument_type):
        # Iterable types accept one or more values (nargs="+").
        kwargs["type"] = get_first_type_argument(argument_type)
        kwargs["nargs"] = "+"
        kwargs["metavar"] = "{}".format(get_list_arg_type_as_str(argument_type))
    else:
        kwargs["type"] = argument_type

    if arg.choices:
        kwargs["choices"] = arg.choices
        # Rendered as "{a,b,c}".
        kwargs["metavar"] = "{{{}}}".format(",".join(map(str, arg.choices)))

    # Positional metavars are shown as name<METAVAR>.
    if arg.positional and "metavar" in kwargs:
        kwargs["metavar"] = "{}<{}>".format(arg.name, kwargs["metavar"])

    return names, kwargs
def find_command(parser, parsed_args, curry_args=False):
    """Return the function registered via ``add_command`` for the parsed command.

    Args:
        parser: the parser (or subparsers action) commands were added to.
        parsed_args: the namespace returned by ``parser.parse_args``.
        curry_args: when True, return a ``functools.partial`` with the
            command's arguments from *parsed_args* already bound.

    Returns:
        The command function (or a partial of it). Returns None when no
        command was selected or the selected parser has no ``__command``
        back-reference.
    """
    subparsers = _resolve_subparsers(parser)
    # ``_cmd`` is the dest _resolve_subparsers uses when it creates the
    # subparsers action; tolerate namespaces that do not carry it.
    command_name = getattr(parsed_args, "_cmd", None)
    command_parser = subparsers._name_parser_map.get(command_name)
    # Only parsers created through add_command carry the back-reference.
    function = getattr(command_parser, "__command", None)
    if not function:
        return None
    if curry_args:
        kwargs = get_arguments_for_command(function, parsed_args)
        function = partial(function, **kwargs)
    return function
def get_arguments_for_inspection(inspection, kwargs):
    """Map parsed argparse values back to the argument names the command takes.

    Args:
        inspection: result of ``inspect_object`` for the command.
        kwargs: mapping of argparse dest names to parsed values.

    Returns:
        A new dict keyed by the command function's real argument names. It
        only holds entries the command accepts, so entries such as ``_cmd``
        (used to identify the parser) are dropped.
    """
    arg_objs = list(inspection.arguments.values())

    # Reverse map: each name used while parsing -- primary names first,
    # then extra names, with dashes turned into underscores -- to the
    # argument the function actually takes.
    names_to_args = {}
    for arg_obj in arg_objs:
        names_to_args[transform_name(arg_obj.name, to_char="_")] = arg_obj.arg
    for arg_obj in arg_objs:
        for extra_name in arg_obj.extra_names:
            names_to_args[transform_name(extra_name, to_char="_")] = arg_obj.arg

    accepted = {arg_obj.arg for arg_obj in arg_objs}

    result = {}
    for name, value in kwargs.items():
        target = names_to_args.get(name, name)
        if target in accepted:
            result[target] = value
    return result
def get_arguments_for_command(function, parsed_args):
    """Return the keyword arguments from *parsed_args* that *function* accepts.

    The names and aliases used on the command line are mapped back to the
    function's real argument names by ``get_arguments_for_inspection``.
    """
    return get_arguments_for_inspection(
        inspect_object(function), dict(parsed_args._get_kwargs())
    )
def transform_argument_name(name):
    """
    Turn an argument name into the flag used in cli mode.

    Similar to transform_name, but specific to exporting argument names for
    the cli mode. Single character names are treated as short flags and get
    a single dash (-) instead of a double dash (--).

    For instance:
        __special__ => --special
        _some_arg => --some-arg
        _f => -f
    """
    flag = transform_name(name)
    dashes = "--" if len(flag) > 1 else "-"
    return "{}{}".format(dashes, flag)
def _parse_dict(target_type):
    """Return an argparse ``type=`` callable that converts a value to *target_type*.

    The conversion is delegated to ``build_value`` with ``python_syntax=False``.
    """

    def parse_dict_(value):
        parsed = build_value(value, target_type, python_syntax=False)
        return parsed

    return parse_dict_
class NubiaHelpAction(argparse.Action):
    """An action that pipes help message to the pager.

    The help text goes to ``$PAGER`` (``less`` when unset) if stdout is a
    terminal and the text has more lines than the terminal. Otherwise, or
    when the pager cannot be started, it is printed directly. The parser
    exits afterwards in every case.
    """

    def __init__(
        self,
        option_strings,
        dest=argparse.SUPPRESS,
        default=argparse.SUPPRESS,
        help=None,
    ):
        # nargs=0: the flag consumes no values from the command line.
        super(NubiaHelpAction, self).__init__(
            option_strings=option_strings,
            dest=dest,
            default=default,
            nargs=0,
            help=help,
        )

    def __call__(self, parser, namespace, values, option_string=None):
        import shlex  # only needed to split $PAGER

        help_message = parser.format_help()
        help_message_length = len(help_message.split("\n"))
        _, rows = shutil.get_terminal_size()
        fits_one_page = help_message_length <= rows

        paged = False
        if sys.stdout.isatty() and not fits_one_page:
            # $PAGER may carry options (e.g. "less -R"), so split it the way
            # a shell would instead of treating it as one program name.
            try:
                pager_cmd = shlex.split(os.environ.get("PAGER", "less"))
            except ValueError:  # e.g. unbalanced quotes in $PAGER
                pager_cmd = []
            if pager_cmd:
                try:
                    subprocess.run(pager_cmd, input=help_message.encode())
                    paged = True
                except OSError:
                    # Pager missing or not executable: print instead.
                    paged = False

        if not paged:  # fallback
            parser.print_help()
        parser.exit()