testslide/dsl.py (226 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
import inspect
from functools import partial
from re import sub as _sub
from typing import Any, Callable, NoReturn, Optional, Union
from testslide import Context, TestCase
from . import Context as _Context
from . import Skip # noqa: F401
ExampleFunction = Callable[..., Any]
HaltingFunction = Callable[..., NoReturn]
def _validate_parameter(
code: ExampleFunction, name: str, index: int, allow_async: bool = True
) -> None:
parameters = list(inspect.signature(code).parameters.keys())
if not parameters or parameters[index] != name:
raise ValueError(
f"Function must receive parameter #{index+1} named "
f"'{name}', but given function has parameters: {parameters}."
)
if not allow_async and inspect.iscoroutinefunction(code):
raise RuntimeError(
f"TestSlide DSL context function `{code.__name__}` can not be async!"
)
def _require_context(action: str) -> Callable:
def wrapper(func: Callable) -> Callable:
@functools.wraps(func)
def func_with_context_validation(
self: "_DSLContext", *args: Any, **kwargs: Any
) -> None:
if not self.current_context:
raise TypeError("Can not {} without a parent context".format(action))
return func(self, *args, **kwargs)
return func_with_context_validation
return wrapper
class _DSLContext:
"""
This class implement TestSlide DSL. This is not intended to be used
directly.
"""
def __init__(
self,
current_context: Optional[Context] = None,
skip: bool = False,
focus: bool = False,
) -> None:
self.current_context = current_context
self.skip = skip
self.focus = focus
@staticmethod
def _not_callable(*args: Any, **kwargs: Any) -> NoReturn:
raise BaseException("This function should not be called outside test code.")
@staticmethod
def _name_from_function(func: ExampleFunction) -> str:
return _sub("_", " ", func.__name__)
def _create_context(
self, name: str, context_code: ExampleFunction, *args: Any, **kwargs: Any
) -> HaltingFunction:
if not self.current_context:
new_context = _Context(name, skip=self.skip, focus=self.focus)
else:
new_context = self.current_context.add_child_context(
name, skip=self.skip, focus=self.focus
)
_validate_parameter(context_code, "context", 0, allow_async=False)
context_code(
type(self)(current_context=new_context, skip=self.skip, focus=self.focus),
*args,
**kwargs,
)
return self._not_callable
def __call__(
self, arg: Union[str, ExampleFunction]
) -> Union[partial, HaltingFunction]:
if callable(arg):
context_code = arg
name = self._name_from_function(context_code)
return self._create_context(name, context_code)
else:
name = arg
return functools.partial(self._create_context, name)
def _reset(self) -> None:
self.skip = False
self.focus = False
# nested contexts
def sub_context(
self, arg: Union[str, ExampleFunction]
) -> Union[partial, HaltingFunction]:
self._reset()
return self(arg)
def xsub_context(self, arg: ExampleFunction) -> HaltingFunction:
self._reset()
self.skip = True
return self(arg)
def fsub_context(self, arg: ExampleFunction) -> HaltingFunction:
self._reset()
self.focus = True
return self(arg)
# Examples
@_require_context("create example")
def _create_example(
self,
name: Optional[str],
example_code: ExampleFunction,
skip: bool,
focus: bool,
) -> HaltingFunction:
if name is None:
name = self._name_from_function(example_code)
_validate_parameter(example_code, "self", 0)
self.current_context.add_example(name, example_code, skip=skip, focus=focus) # type: ignore
return self._not_callable
def example(
self,
arg: Optional[Union[str, ExampleFunction]] = None,
skip: bool = False,
focus: bool = False,
skip_unless: bool = True,
) -> Union[partial, HaltingFunction]:
skip = skip or not skip_unless
if callable(arg):
example_code = arg
name = self._name_from_function(example_code)
return self._create_example(name, example_code, skip=skip, focus=focus)
else:
name = arg # type: ignore
return functools.partial(self._create_example, name, skip=skip, focus=focus)
def xexample(self, arg: Union[str, ExampleFunction]) -> HaltingFunction:
return self.example(arg, skip=True)
def fexample(self, arg: Union[str, ExampleFunction]) -> HaltingFunction:
return self.example(arg, focus=True)
# Shared contexts
@_require_context("create a shared context")
def _create_shared_context(
self, name: str, shared_context_code: ExampleFunction
) -> HaltingFunction:
_validate_parameter(shared_context_code, "context", 0)
self.current_context.add_shared_context(name, shared_context_code) # type: ignore
return self._not_callable
def shared_context(
self, arg: Union[str, ExampleFunction]
) -> Union[partial, HaltingFunction]:
if callable(arg):
shared_context_code = arg
name = self._name_from_function(shared_context_code)
return self._create_shared_context(name, shared_context_code)
else:
name = arg
return functools.partial(self._create_shared_context, name)
@_require_context("merge a shared context")
def merge_context(self, name: str, *args: Any, **kwargs: Any) -> None:
if name not in self.current_context.all_shared_contexts: # type: ignore
raise TypeError('Shared context "{}" does not exist'.format(name))
self.current_context.all_shared_contexts[name](self, *args, **kwargs) # type: ignore
@_require_context("merge a TestCase")
def merge_test_case(self, test_case: "TestCase", attr_name: str) -> HaltingFunction:
self.current_context.add_test_case(test_case, attr_name) # type:ignore
return self._not_callable
@_require_context("nest a shared context")
def nest_context(self, name: str, *args: Any, **kwargs: Any) -> None:
if name not in self.current_context.all_shared_contexts: # type:ignore
raise TypeError('Shared context "{}" does not exist'.format(name))
self._create_context(
name, self.current_context.all_shared_contexts[name], *args, **kwargs # type: ignore
)
# Helper function
@_require_context("create functions")
def function(self, function_code: ExampleFunction) -> HaltingFunction:
_validate_parameter(function_code, "self", 0)
self.current_context.add_function(function_code.__name__, function_code) # type: ignore
return self._not_callable
# Memoizable attributes
@_require_context("create memoizable attributes")
def memoize(
self,
name_or_code: Optional[Union[str, ExampleFunction]] = None,
memoizable_code: Optional[ExampleFunction] = None,
**kwargs: Any,
) -> HaltingFunction:
_memoizable_code: ExampleFunction
if name_or_code:
if kwargs:
raise ValueError("Invalid arguments!")
if memoizable_code: # name + code
name = name_or_code
_memoizable_code = memoizable_code
else: # used as decorator
name = name_or_code.__name__ # type: ignore
_memoizable_code = name_or_code # type: ignore
_validate_parameter(_memoizable_code, "self", 0)
self.current_context.add_memoized_attribute(name, _memoizable_code) # type: ignore
else: # kwargs
if name_or_code or memoizable_code:
raise ValueError("Invalid arguments!")
for name, code in kwargs.items():
self.memoize(name, code)
return self._not_callable
@_require_context("create a memoize before attribute")
def memoize_before(
self,
name_or_code: Union[str, ExampleFunction],
memoizable_code: Optional[ExampleFunction] = None,
) -> HaltingFunction:
_memoizable_code: ExampleFunction
if memoizable_code: # Got a lambda
name = name_or_code
_memoizable_code = memoizable_code
else: # Got a function
name = name_or_code.__name__ # type: ignore
_memoizable_code = name_or_code # type: ignore
_validate_parameter(_memoizable_code, "self", 0)
self.current_context.add_memoized_attribute(name, _memoizable_code, before=True) # type: ignore
return self._not_callable
# Hooks
@_require_context("register before hook")
def before(self, before_code: ExampleFunction) -> HaltingFunction:
_validate_parameter(before_code, "self", 0)
if not self.current_context:
raise TypeError("Can not register before hook at top level context")
self.current_context.before_functions.append(before_code)
return self._not_callable
@_require_context("register after hook")
def after(self, after_code: ExampleFunction) -> HaltingFunction:
_validate_parameter(after_code, "self", 0)
self.current_context.after_functions.append(after_code) # type: ignore
return self._not_callable
@_require_context("register around hook")
def around(self, around_code: ExampleFunction) -> HaltingFunction:
_validate_parameter(around_code, "self", 0)
_validate_parameter(around_code, "wrapped", 1)
if not self.current_context:
raise TypeError("Can not register around hook at top level context")
self.current_context.around_functions.append(around_code)
return self._not_callable
context = _DSLContext()
xcontext = _DSLContext(skip=True)
fcontext = _DSLContext(focus=True)