testslide/dsl.py (226 lines of code) (raw):

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import functools
import inspect
from functools import partial
from re import sub as _sub
from typing import Any, Callable, NoReturn, Optional, Union

from testslide import Context, TestCase

from . import Context as _Context
from . import Skip  # noqa: F401

# A user-supplied callable registered through the DSL (context body, example,
# hook, helper function, memoized attribute, ...).
ExampleFunction = Callable[..., Any]
# The type of ``_DSLContext._not_callable``: a callable that never returns.
HaltingFunction = Callable[..., NoReturn]


def _validate_parameter(
    code: ExampleFunction, name: str, index: int, allow_async: bool = True
) -> None:
    """
    Check that ``code`` declares a parameter called ``name`` at position
    ``index`` (0-based).

    Raises:
        ValueError: if ``code`` has fewer than ``index + 1`` parameters, or the
            parameter at ``index`` is not called ``name``.
        RuntimeError: if ``allow_async`` is False and ``code`` is a coroutine
            function.
    """
    parameters = list(inspect.signature(code).parameters.keys())
    # Bounds-check before indexing: e.g. an ``around`` hook declared with only
    # ``self`` must get the descriptive ValueError, not a bare IndexError.
    if len(parameters) <= index or parameters[index] != name:
        raise ValueError(
            f"Function must receive parameter #{index+1} named "
            f"'{name}', but given function has parameters: {parameters}."
        )
    if not allow_async and inspect.iscoroutinefunction(code):
        raise RuntimeError(
            f"TestSlide DSL context function `{code.__name__}` can not be async!"
        )


def _require_context(action: str) -> Callable:
    """
    Decorator factory for ``_DSLContext`` methods that need a parent context.

    The wrapped method raises ``TypeError`` ("Can not <action> without a parent
    context") when ``self.current_context`` is falsy. Otherwise the method is
    called and its return value is passed through unchanged.
    """

    def wrapper(func: Callable) -> Callable:
        @functools.wraps(func)
        def func_with_context_validation(
            self: "_DSLContext", *args: Any, **kwargs: Any
        ) -> Any:
            if not self.current_context:
                raise TypeError("Can not {} without a parent context".format(action))
            return func(self, *args, **kwargs)

        return func_with_context_validation

    return wrapper


class _DSLContext:
    """
    This class implements the TestSlide DSL. It is not intended to be used
    directly: use the module-level ``context`` / ``xcontext`` / ``fcontext``
    objects, or the ``context`` argument each context function receives.
    """

    def __init__(
        self,
        current_context: Optional[Context] = None,
        skip: bool = False,
        focus: bool = False,
    ) -> None:
        # Context that new definitions are attached to; None at top level.
        self.current_context = current_context
        # Flags forwarded to the next context created through this object.
        self.skip = skip
        self.focus = focus

    @staticmethod
    def _not_callable(*args: Any, **kwargs: Any) -> NoReturn:
        """
        Returned by every DSL decorator, so the decorated name is bound to this
        function instead of the user's code; calling it always raises.
        """
        raise BaseException("This function should not be called outside test code.")

    @staticmethod
    def _name_from_function(func: ExampleFunction) -> str:
        """Derive a display name from ``func.__name__`` (underscores -> spaces)."""
        return _sub("_", " ", func.__name__)

    def _create_context(
        self, name: str, context_code: ExampleFunction, *args: Any, **kwargs: Any
    ) -> HaltingFunction:
        """
        Create context ``name`` and immediately run ``context_code`` in it.

        The context is top-level when this object has no parent; otherwise it
        is added as a child of ``self.current_context``. ``context_code``
        receives a new DSL object bound to the new context, followed by
        ``*args`` / ``**kwargs``.

        Raises:
            ValueError / RuntimeError: from ``_validate_parameter`` when
                ``context_code`` does not take ``context`` first or is async.
        """
        # Validate first, so a rejected function never reaches _Context(...)
        # or add_child_context(...).
        _validate_parameter(context_code, "context", 0, allow_async=False)
        if not self.current_context:
            new_context = _Context(name, skip=self.skip, focus=self.focus)
        else:
            new_context = self.current_context.add_child_context(
                name, skip=self.skip, focus=self.focus
            )
        # type(self) rather than _DSLContext, so a subclass propagates into
        # nested context functions.
        context_code(
            type(self)(current_context=new_context, skip=self.skip, focus=self.focus),
            *args,
            **kwargs,
        )
        return self._not_callable

    def __call__(
        self, arg: Union[str, ExampleFunction]
    ) -> Union[partial, HaltingFunction]:
        """
        Create a context.

        Usable bare as ``@context`` (the name is derived from the function), or
        as ``@context("name")`` (returns a decorator taking the function).
        """
        if callable(arg):
            context_code = arg
            name = self._name_from_function(context_code)
            return self._create_context(name, context_code)
        else:
            name = arg
            return functools.partial(self._create_context, name)

    def _reset(self) -> None:
        """
        Clear the skip/focus flags before creating a sub context.

        NOTE(review): xsub_context/fsub_context leave the flag they set on this
        object afterwards. A later nest_context() (which does not call _reset)
        or a direct call on the same object therefore inherits it. Confirm
        whether that is intended.
        """
        self.skip = False
        self.focus = False

    # nested contexts

    def sub_context(
        self, arg: Union[str, ExampleFunction]
    ) -> Union[partial, HaltingFunction]:
        """Create a nested context with skip/focus cleared."""
        self._reset()
        return self(arg)

    def xsub_context(self, arg: ExampleFunction) -> HaltingFunction:
        """Create a nested context with ``skip=True``."""
        self._reset()
        self.skip = True
        return self(arg)

    def fsub_context(self, arg: ExampleFunction) -> HaltingFunction:
        """Create a nested context with ``focus=True``."""
        self._reset()
        self.focus = True
        return self(arg)

    # Examples

    @_require_context("create example")
    def _create_example(
        self,
        name: Optional[str],
        example_code: ExampleFunction,
        skip: bool,
        focus: bool,
    ) -> HaltingFunction:
        """
        Add ``example_code`` as an example of the current context.

        ``name`` defaults to one derived from the function when None.
        ``example_code`` must take ``self`` as its first parameter.
        """
        if name is None:
            name = self._name_from_function(example_code)
        _validate_parameter(example_code, "self", 0)
        self.current_context.add_example(name, example_code, skip=skip, focus=focus)  # type: ignore
        return self._not_callable

    def example(
        self,
        arg: Optional[Union[str, ExampleFunction]] = None,
        skip: bool = False,
        focus: bool = False,
        skip_unless: bool = True,
    ) -> Union[partial, HaltingFunction]:
        """
        Define an example.

        Usable as ``@context.example``, ``@context.example("name")`` or
        ``@context.example(skip=..., focus=..., skip_unless=...)``.
        A False ``skip_unless`` marks the example as skipped, the same as
        ``skip=True``.
        """
        skip = skip or not skip_unless
        if callable(arg):
            example_code = arg
            name = self._name_from_function(example_code)
            return self._create_example(name, example_code, skip=skip, focus=focus)
        else:
            # A string name, or None to derive the name from the function.
            name = arg  # type: ignore
            return functools.partial(self._create_example, name, skip=skip, focus=focus)

    def xexample(self, arg: Union[str, ExampleFunction]) -> HaltingFunction:
        """Shorthand for ``example(arg, skip=True)``."""
        return self.example(arg, skip=True)

    def fexample(self, arg: Union[str, ExampleFunction]) -> HaltingFunction:
        """Shorthand for ``example(arg, focus=True)``."""
        return self.example(arg, focus=True)

    # Shared contexts

    @_require_context("create a shared context")
    def _create_shared_context(
        self, name: str, shared_context_code: ExampleFunction
    ) -> HaltingFunction:
        """
        Register ``shared_context_code`` under ``name`` on the current context.

        The function must take ``context`` as its first parameter.
        """
        _validate_parameter(shared_context_code, "context", 0)
        self.current_context.add_shared_context(name, shared_context_code)  # type: ignore
        return self._not_callable

    def shared_context(
        self, arg: Union[str, ExampleFunction]
    ) -> Union[partial, HaltingFunction]:
        """
        Define a shared context.

        Usable as ``@context.shared_context`` or ``@context.shared_context("name")``.
        """
        if callable(arg):
            shared_context_code = arg
            name = self._name_from_function(shared_context_code)
            return self._create_shared_context(name, shared_context_code)
        else:
            name = arg
            return functools.partial(self._create_shared_context, name)

    @_require_context("merge a shared context")
    def merge_context(self, name: str, *args: Any, **kwargs: Any) -> None:
        """
        Run shared context ``name`` with this DSL object itself.

        Its definitions therefore go into the current context; no new context
        is created. Extra arguments are forwarded to the shared context.

        Raises:
            TypeError: if no shared context called ``name`` is visible.
        """
        if name not in self.current_context.all_shared_contexts:  # type: ignore
            raise TypeError('Shared context "{}" does not exist'.format(name))
        self.current_context.all_shared_contexts[name](self, *args, **kwargs)  # type: ignore

    @_require_context("merge a TestCase")
    def merge_test_case(self, test_case: "TestCase", attr_name: str) -> HaltingFunction:
        """
        Delegate to ``current_context.add_test_case(test_case, attr_name)``.

        NOTE(review): the meaning of ``attr_name`` is defined by
        ``Context.add_test_case``, which is not visible here.
        """
        self.current_context.add_test_case(test_case, attr_name)  # type:ignore
        return self._not_callable

    @_require_context("nest a shared context")
    def nest_context(self, name: str, *args: Any, **kwargs: Any) -> None:
        """
        Create a child context named ``name`` whose body is shared context
        ``name``. Extra arguments are forwarded to the shared context.

        Raises:
            TypeError: if no shared context called ``name`` is visible.
        """
        if name not in self.current_context.all_shared_contexts:  # type:ignore
            raise TypeError('Shared context "{}" does not exist'.format(name))
        self._create_context(
            name, self.current_context.all_shared_contexts[name], *args, **kwargs  # type: ignore
        )

    # Helper function

    @_require_context("create functions")
    def function(self, function_code: ExampleFunction) -> HaltingFunction:
        """
        Register ``function_code`` as a helper on the current context, under
        its ``__name__``. It must take ``self`` as its first parameter.
        """
        _validate_parameter(function_code, "self", 0)
        self.current_context.add_function(function_code.__name__, function_code)  # type: ignore
        return self._not_callable

    # Memoizable attributes

    @_require_context("create memoizable attributes")
    def memoize(
        self,
        name_or_code: Optional[Union[str, ExampleFunction]] = None,
        memoizable_code: Optional[ExampleFunction] = None,
        **kwargs: Any,
    ) -> HaltingFunction:
        """
        Register memoized attribute(s) on the current context.

        Accepted forms:
          * ``@context.memoize`` on a function: the name is ``func.__name__``;
          * ``context.memoize("name", code)``;
          * ``context.memoize(name1=code1, name2=code2, ...)``.

        Every memoizable function must take ``self`` as its first parameter.

        Raises:
            ValueError: when positional and keyword forms are mixed.
        """
        _memoizable_code: ExampleFunction
        if name_or_code:
            if kwargs:
                raise ValueError("Invalid arguments!")
            if memoizable_code:  # name + code
                name = name_or_code
                _memoizable_code = memoizable_code
            else:  # used as decorator
                name = name_or_code.__name__  # type: ignore
                _memoizable_code = name_or_code  # type: ignore
            _validate_parameter(_memoizable_code, "self", 0)
            self.current_context.add_memoized_attribute(name, _memoizable_code)  # type: ignore
        else:  # kwargs
            if name_or_code or memoizable_code:
                raise ValueError("Invalid arguments!")
            for name, code in kwargs.items():
                self.memoize(name, code)
        return self._not_callable

    @_require_context("create a memoize before attribute")
    def memoize_before(
        self,
        name_or_code: Union[str, ExampleFunction],
        memoizable_code: Optional[ExampleFunction] = None,
    ) -> HaltingFunction:
        """
        Like ``memoize`` (decorator or name + code forms), but registered with
        ``before=True``.

        NOTE(review): presumably this evaluates the attribute up front rather
        than lazily; confirm in ``Context.add_memoized_attribute``.
        """
        _memoizable_code: ExampleFunction
        if memoizable_code:  # Got a lambda
            name = name_or_code
            _memoizable_code = memoizable_code
        else:  # Got a function
            name = name_or_code.__name__  # type: ignore
            _memoizable_code = name_or_code  # type: ignore
        _validate_parameter(_memoizable_code, "self", 0)
        self.current_context.add_memoized_attribute(name, _memoizable_code, before=True)  # type: ignore
        return self._not_callable

    # Hooks

    @_require_context("register before hook")
    def before(self, before_code: ExampleFunction) -> HaltingFunction:
        """
        Append ``before_code`` to the current context's before hooks.
        It must take ``self`` as its first parameter.
        """
        # No separate top-level check is needed: _require_context has already
        # rejected a missing parent context.
        _validate_parameter(before_code, "self", 0)
        self.current_context.before_functions.append(before_code)  # type: ignore
        return self._not_callable

    @_require_context("register after hook")
    def after(self, after_code: ExampleFunction) -> HaltingFunction:
        """
        Append ``after_code`` to the current context's after hooks.
        It must take ``self`` as its first parameter.
        """
        _validate_parameter(after_code, "self", 0)
        self.current_context.after_functions.append(after_code)  # type: ignore
        return self._not_callable

    @_require_context("register around hook")
    def around(self, around_code: ExampleFunction) -> HaltingFunction:
        """
        Append ``around_code`` to the current context's around hooks.
        It must take ``(self, wrapped, ...)``.
        """
        _validate_parameter(around_code, "self", 0)
        _validate_parameter(around_code, "wrapped", 1)
        self.current_context.around_functions.append(around_code)  # type: ignore
        return self._not_callable


# Module-level entry points: a plain DSL object, and ones created with
# skip=True and focus=True respectively.
context = _DSLContext()
xcontext = _DSLContext(skip=True)
fcontext = _DSLContext(focus=True)