testslide/cli.py (371 lines of code) (raw):

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
import re
import sys
import unittest
from contextlib import contextmanager
from dataclasses import dataclass
from time import time
from typing import Any, Callable, Iterator, List, Optional, Pattern, Type

import testslide.dsl

from . import Context, TestCase, _TestSlideTestResult
from .runner import DocumentFormatter, LongFormatter, ProgressFormatter, Runner
from .strict_mock import StrictMock

# Guards _load_unittest_test_cases() so unittest classes are converted only once.
_unittest_testcase_loaded: bool = False


def _filename_to_module_name(name: str) -> str:
    """Convert the path of a ``.py``/``.pyc`` file into a dotted module name.

    Absolute paths are first made relative to the current working directory, so
    the resulting name is relative to that directory.

    Raises:
        ValueError: if ``name`` is not an existing file ending in ``.py``/``.pyc``
            (case-insensitive).
    """
    lowered = name.lower()
    if not (os.path.isfile(name) and lowered.endswith((".py", ".pyc"))):
        raise ValueError("Expected a .py file, got {}".format(name))

    if os.path.isabs(name):
        name = os.path.relpath(name, os.getcwd())
    # Strip the extension: 4 characters for ".pyc", 3 for ".py". relpath() does
    # not touch the extension, so checking the original name is equivalent.
    end = -4 if lowered.endswith(".pyc") else -3
    return name[:end].replace(os.path.sep, ".")


def _get_all_test_case_subclasses() -> List[TestCase]:
    """Return every class that directly or transitively subclasses unittest.TestCase.

    Classes are de-duplicated by ``"<module>.<__name__>"``, so a class reachable
    through several bases (multiple inheritance) appears only once.
    """

    def get_all_subclasses(base: Type[unittest.TestCase]) -> List[TestCase]:
        return list(
            {  # type: ignore
                "{}.{}".format(c.__module__, c.__name__): c
                for c in (
                    base.__subclasses__()  # type: ignore
                    + [g for s in base.__subclasses__() for g in get_all_subclasses(s)]  # type: ignore
                )
            }.values()
        )

    return get_all_subclasses(unittest.TestCase)


def _get_all_test_cases(import_module_names: List[str]) -> List[TestCase]:
    """Return unittest.TestCase subclasses defined in ``import_module_names``.

    An empty ``import_module_names`` means "no filter": every known subclass is
    returned.
    """
    if import_module_names:
        return [
            test_case
            for test_case in _get_all_test_case_subclasses()
            if test_case.__module__ in import_module_names
        ]
    else:
        return _get_all_test_case_subclasses()


@contextmanager
def _unittest_result() -> Iterator[_TestSlideTestResult]:
    """Yield a fresh result and, on a clean exit, raise what it aggregated."""
    result = _TestSlideTestResult()
    yield result
    # unittest.TestCase.run() records failures in the result object instead of
    # raising them, so they are surfaced here for TestSlide to report.
    result.aggregated_exceptions.raise_correct_exception()


@contextmanager
def _unittest_class_fixture(test_case: Type[unittest.TestCase]) -> Iterator[None]:
    """Run ``test_case``'s class-level setUpClass/tearDownClass around one example.

    tearDownClass() runs from ``finally`` so it is not skipped when the wrapped
    code raises.
    """
    test_case.setUpClass()
    try:
        yield
    finally:
        test_case.tearDownClass()


def _unittest_example_code(
    test_case: Type[unittest.TestCase], test_method_name: str
) -> Callable:
    """Build a TestSlide example body that runs ``test_case.<test_method_name>``.

    Both values are bound as arguments, so each generated example keeps its own
    class and method name instead of sharing a late-bound loop variable.
    """

    def example_code(self: Any) -> None:
        with _unittest_result() as result:
            with _unittest_class_fixture(test_case):
                test_case(methodName=test_method_name)(  # type: ignore
                    result=result
                )

    return example_code


def _unittest_context_code(
    test_case: Type[unittest.TestCase], test_method_names: List[str]
) -> Callable[[testslide.dsl._DSLContext], None]:
    """Build a DSL context body that registers one example per test method.

    Methods named ``test*`` become regular examples, ``ftest*`` focused examples
    and ``xtest*`` skipped examples. ``test_case`` and ``test_method_names`` are
    bound as arguments, so the body is correct no matter when the DSL calls it.
    """

    def context_code(context: testslide.dsl._DSLContext) -> None:
        for test_method_name in test_method_names:
            example_code = _unittest_example_code(test_case, test_method_name)
            if test_method_name.startswith("test"):
                # Regular example
                context.example(test_method_name)(example_code)
            elif test_method_name.startswith("ftest"):
                # Focused example
                context.fexample(test_method_name)(example_code)
            elif test_method_name.startswith("xtest"):
                # Skipped example
                context.xexample(test_method_name)(example_code)

    return context_code


def _load_unittest_test_cases(import_module_names: List[str]) -> None:
    """
    Beta!
    Search for all unittest.TestCase classes that have tests defined, and import them
    as TestSlide contexts and examples. This is useful if you mix unittest.TestCase
    tests and TestSlide in the same file, or if you want to just use TestSlide's test
    runner for existing unittest.TestCase tests.

    Only the first call does any work; later calls return immediately.
    """
    global _unittest_testcase_loaded
    if _unittest_testcase_loaded:
        return
    _unittest_testcase_loaded = True

    for test_case in _get_all_test_cases(import_module_names):
        test_method_names = [
            test_method_name
            for test_method_name in dir(test_case)
            if test_method_name.startswith(("test", "ftest", "xtest"))
            # inspect.ismethod() is False for plain functions looked up on a
            # class (Python 3 has no unbound methods), hence callable().
            and callable(getattr(test_case, test_method_name))
        ]
        if not test_method_names:
            continue

        context_code = _unittest_context_code(test_case, test_method_names)  # type: ignore
        testslide.dsl.context("{}.{}".format(test_case.__module__, test_case.__name__))(  # type: ignore
            context_code
        )


@dataclass(frozen=True)
class _Config:
    """Immutable, normalized view of the parsed command-line options."""

    import_module_names: List[str]
    shuffle: bool
    list: bool
    quiet: bool
    fail_if_focused: bool
    fail_fast: bool
    focus: bool
    trim_path_prefix: str
    format: str
    seed: Optional[int] = None
    force_color: Optional[bool] = False
    show_testslide_stack_trace: Optional[bool] = False
    names_text_filter: Optional[str] = None
    names_regex_filter: Optional[Pattern[Any]] = None
    names_regex_exclude: Optional[Pattern[Any]] = None
    dsl_debug: Optional[bool] = False
    # When set, only imports are profiled (see Cli.run); no tests are executed.
    profile_threshold_ms: Optional[int] = None


class Cli:
    """TestSlide command-line front end: parse arguments, import tests, run them."""

    # Every accepted --format value (short and long spelling) -> formatter class.
    FORMAT_NAME_TO_FORMATTER_CLASS = {
        "p": ProgressFormatter,
        "progress": ProgressFormatter,
        "d": DocumentFormatter,
        "documentation": DocumentFormatter,
        "l": LongFormatter,
        "long": LongFormatter,
    }

    @staticmethod
    def _regex_type(string: str) -> Pattern:
        """argparse ``type=`` callable that compiles the option value as a regex."""
        return re.compile(string)

    def _build_parser(self, disable_test_files: bool) -> argparse.ArgumentParser:
        """Build the argument parser.

        Args:
            disable_test_files: when True the positional ``test_files`` argument is
                omitted (used when modules are supplied programmatically).
        """
        parser = argparse.ArgumentParser(description="TestSlide")
        parser.add_argument(
            "-f",
            "--format",
            choices=self.FORMAT_NAME_TO_FORMATTER_CLASS.keys(),
            default="documentation",
            help="Configure output format. Default: %(default)s",
        )
        parser.add_argument(
            "--force-color",
            action="store_true",
            help="Force color output even without a terminal",
        )
        parser.add_argument(
            "--shuffle", action="store_true", help="Randomize example execution order"
        )
        parser.add_argument(
            "-l", "--list", action="store_true", help="List all tests one per line"
        )
        parser.add_argument(
            "--seed",
            nargs=1,
            type=int,
            help="Positive number to seed shuffled examples",
        )
        parser.add_argument(
            "--focus",
            action="store_true",
            help="Only execute focused examples, or all if none are focused",
        )
        parser.add_argument(
            "--fail-if-focused",
            action="store_true",
            help="Raise an error if an example is focused. Useful when running tests in a continuous integration environment.",
        )
        parser.add_argument(
            "--fail-fast",
            action="store_true",
            help="Stop execution when an example fails",
        )
        parser.add_argument(
            "--filter-text",
            nargs=1,
            type=str,
            help="Only execute examples that include given text in their names",
        )
        parser.add_argument(
            "--filter-regex",
            nargs=1,
            type=self._regex_type,
            help="Only execute examples which match given regex",
        )
        parser.add_argument(
            "--exclude-regex",
            nargs=1,
            type=self._regex_type,
            help="Exclude examples which match given regex from being executed",
        )
        parser.add_argument(
            "--quiet",
            action="store_true",
            help="Suppress output (stdout and stderr) of tested code",
        )
        parser.add_argument(
            "--dsl-debug",
            action="store_true",
            help=(
                "Print debugging information during execution of TestSlide's "
                "DSL tests."
            ),
        )
        parser.add_argument(
            "--trim-path-prefix",
            nargs=1,
            type=str,
            default=[self._default_trim_path_prefix],
            help=(
                "Remove the specified prefix from paths in some of the output. "
                "Default: {}".format(repr(self._default_trim_path_prefix))
            ),
        )
        parser.add_argument(
            "--show-testslide-stack-trace",
            default=False,
            action="store_true",
            help=(
                "TestSlide's own code is trimmed from stack traces by default. "
                "This flag disables that, useful for TestSlide's own development."
            ),
        )
        parser.add_argument(
            "--import-profiler",
            nargs=1,
            type=int,
            default=None,
            help=(
                "Print import-time profiling information for modules that took "
                "more than the given number of ms to import. Experimental."
            ),
        )
        if not disable_test_files:
            parser.add_argument(
                "test_files",
                nargs="+",
                type=str,
                default=[],
                help=(
                    "List of file paths that contain either unittest.TestCase "
                    "tests and/or TestSlide's DSL tests."
                ),
            )
        return parser

    def __init__(
        self,
        args: Any,
        default_trim_path_prefix: Optional[str] = None,
        modules: Optional[List[str]] = None,
    ) -> None:
        """
        Args:
            args: argument list handed to ``argparse`` (e.g. ``sys.argv[1:]``).
            default_trim_path_prefix: default for --trim-path-prefix; falls back to
                the current working directory plus a path separator.
            modules: module names to import instead of positional test files.
        """
        self.args = args
        self._default_trim_path_prefix = (
            default_trim_path_prefix
            if default_trim_path_prefix
            else os.getcwd() + os.sep
        )
        # With explicit modules, positional test file arguments are not accepted.
        self.parser = self._build_parser(disable_test_files=bool(modules))
        self._modules = modules

    @staticmethod
    def _do_imports(
        import_module_names: List[str], profile_threshold_ms: Optional[int] = None
    ) -> float:
        """Import every module in ``import_module_names``; return elapsed seconds.

        When ``profile_threshold_ms`` is given, the imports run under
        ``testslide.import_profiler.ImportProfiler`` and its stats are printed with
        that threshold once the imports finish.
        """

        def import_all() -> None:
            for module_name in import_module_names:
                __import__(module_name, level=0)

        if profile_threshold_ms is None:
            start_time = time()
            import_all()
            return time() - start_time

        from testslide.import_profiler import ImportProfiler

        with ImportProfiler() as import_profiler:
            start_time = time()
            import_all()
            end_time = time()
        import_profiler.print_stats(profile_threshold_ms)
        return end_time - start_time

    def _load_all_examples(self, import_module_names: List[str]) -> float:
        """
        Import all required modules, then register any unittest.TestCase tests
        found in them as TestSlide examples. Returns the import time in seconds.
        """
        import_secs = self._do_imports(import_module_names)
        _load_unittest_test_cases(import_module_names)
        return import_secs

    def _get_config_from_parsed_args(self, parsed_args: Any) -> _Config:
        """Translate an ``argparse.Namespace`` into a ``_Config``.

        Options declared with ``nargs=1`` arrive as one-element lists and are
        unwrapped here; absent ones become None.
        """
        config = _Config(
            format=parsed_args.format,
            force_color=parsed_args.force_color,
            trim_path_prefix=parsed_args.trim_path_prefix[0],
            show_testslide_stack_trace=parsed_args.show_testslide_stack_trace,
            profile_threshold_ms=parsed_args.import_profiler[0]
            if parsed_args.import_profiler
            else None,
            shuffle=parsed_args.shuffle,
            list=parsed_args.list,
            seed=parsed_args.seed[0] if parsed_args.seed else None,
            focus=parsed_args.focus,
            fail_if_focused=parsed_args.fail_if_focused,
            fail_fast=parsed_args.fail_fast,
            names_text_filter=parsed_args.filter_text[0]
            if parsed_args.filter_text
            else None,
            names_regex_filter=parsed_args.filter_regex[0]
            if parsed_args.filter_regex
            else None,
            names_regex_exclude=parsed_args.exclude_regex[0]
            if parsed_args.exclude_regex
            else None,
            quiet=parsed_args.quiet,
            dsl_debug=parsed_args.dsl_debug,
            import_module_names=self._modules
            if self._modules
            else [
                _filename_to_module_name(test_file)
                for test_file in parsed_args.test_files
            ],
        )
        return config

    def run(self) -> int:
        """Parse ``self.args``, then profile imports, list tests or run them.

        Returns:
            The process exit status.
        """
        try:
            parsed_args = self.parser.parse_args(self.args)
        except SystemExit as e:
            # argparse raises SystemExit for --help and usage errors; hand back its
            # status instead of exiting. Mirror sys.exit(): None means success and a
            # non-integer code means failure.
            if e.code is None:
                return 0
            return e.code if isinstance(e.code, int) else 1
        config = self._get_config_from_parsed_args(parsed_args)

        if config.profile_threshold_ms is not None:
            # Profiling mode only measures imports; no tests are loaded or run.
            self._do_imports(config.import_module_names, config.profile_threshold_ms)
            return 0

        import_secs = self._load_all_examples(config.import_module_names)
        formatter = self.FORMAT_NAME_TO_FORMATTER_CLASS[config.format](
            import_module_names=config.import_module_names,
            force_color=config.force_color,
            import_secs=import_secs,
            trim_path_prefix=config.trim_path_prefix,
            show_testslide_stack_trace=config.show_testslide_stack_trace,
            dsl_debug=config.dsl_debug,
        )
        StrictMock.TRIM_PATH_PREFIX = config.trim_path_prefix

        if config.list:
            formatter.discovery_start()
            for context in Context.all_top_level_contexts:
                for example in context.all_examples:
                    formatter.example_discovered(example)
            formatter.discovery_finish()
            return 0

        return Runner(
            contexts=Context.all_top_level_contexts,
            formatter=formatter,
            shuffle=config.shuffle,
            seed=config.seed,
            focus=config.focus,
            fail_fast=config.fail_fast,
            fail_if_focused=config.fail_if_focused,
            names_text_filter=config.names_text_filter,
            names_regex_filter=config.names_regex_filter,
            names_regex_exclude=config.names_regex_exclude,
            quiet=config.quiet,
        ).run()


def main() -> None:
    """Console entry point: run the CLI on ``sys.argv[1:]`` and exit with its status."""
    # "" on sys.path means the current directory, which makes module names built
    # by _filename_to_module_name (relative to the cwd) importable.
    if "" not in sys.path:
        sys.path.insert(0, "")
    try:
        sys.exit(Cli(sys.argv[1:]).run())
    except KeyboardInterrupt:
        print("SIGINT received, exiting.", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()