testslide/__init__.py (684 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
# If COVERAGE_PROCESS_START is set, start coverage measurement in this process
# via coverage.process_startup(). This runs before the remaining imports of
# this module — presumably so their import-time code is measured as well.
if "COVERAGE_PROCESS_START" in os.environ:
    import coverage

    coverage.process_startup()
import asyncio
import asyncio.log
import contextlib
import inspect
import re
import sys
import types
import unittest
import warnings
from contextlib import contextmanager
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterator,
List,
Optional,
TextIO,
Tuple,
Type,
Union,
)
import testslide.matchers
import testslide.mock_callable
import testslide.mock_constructor
import testslide.patch_attribute
from testslide.strict_mock import StrictMock # noqa
if TYPE_CHECKING:
    # hack for Mypy: only type checkers import the formatter type, avoiding a
    # runtime import of testslide.runner from this module.
    from testslide.runner import BaseFormatter

# Reject interpreters older than 3.6 explicitly.
if not sys.version_info >= (3, 6):
    raise RuntimeError("Python >=3.6 required.")
if sys.version_info < (3, 7):

    def asyncio_run(coro):
        """
        Minimal backport of ``asyncio.run(coro, debug=True)`` for Python 3.6.

        Runs ``coro`` to completion on a brand new event loop with debug mode
        enabled, then shuts down async generators and closes the loop.

        Returns the coroutine's result, like ``asyncio.run`` does.
        """
        loop = asyncio.events.new_event_loop()
        try:
            loop.set_debug(True)
            # Propagate the result: asyncio.run() returns it and callers of
            # asyncio_run() rely on that on every Python version.
            return loop.run_until_complete(coro)
        finally:
            try:
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()

else:
    # asyncio.run() manages its own loop; debug=True enables asyncio's debug
    # mode, which among other things logs slow callbacks.
    asyncio_run = partial(asyncio.run, debug=True)
# Python < 3.8 uses the asyncio.Task.all_tasks classmethod (removed in 3.9);
# newer versions expose the module-level asyncio.all_tasks().
get_all_tasks = (
    asyncio.Task.all_tasks if sys.version_info < (3, 8) else asyncio.all_tasks
)
def get_active_tasks():
    """
    Return, as a list, the tasks from get_all_tasks() that are neither done
    nor cancelled.
    """
    active = []
    for task in get_all_tasks():
        if task.done() or task.cancelled():
            continue
        active.append(task)
    return active
class LeftOverActiveTasks(BaseException):
    """
    Raised when asynchronous tasks started during an example are still
    unfinished after it.

    Derives from BaseException rather than Exception.
    """
def _importer(target: str) -> Any:
components = target.split(".")
import_path = components.pop(0)
thing = __import__(import_path)
def dot_lookup(thing: object, comp: str, import_path: str) -> Any:
try:
return getattr(thing, comp)
except AttributeError:
__import__(import_path)
return getattr(thing, comp)
for comp in components:
import_path += ".%s" % comp
thing = dot_lookup(thing, comp, import_path)
return thing
async def _async_ensure_no_leaked_tasks(coro):
    """
    Await ``coro`` and return its result.

    Raises LeftOverActiveTasks if any task that was not active before the
    await is still active once ``coro`` has finished.
    """
    preexisting = set(get_active_tasks())
    result = await coro
    leaked = set(get_active_tasks()) - preexisting
    if not leaked:
        return result
    tasks_str = "\n".join(str(task) for task in leaked)
    raise LeftOverActiveTasks(
        "Some tasks were started but did not finish yet, are you missing "
        f"an `await` somewhere?\nRunning tasks:\n {tasks_str}"
    )
class _ContextData:
    """
    Per-example namespace: the object example code and hooks receive as
    ``self``. A new instance is created for every example run.

    Attribute lookup falls back to the context's functions and memoized
    attributes, ``assert*`` names are forwarded to a ``unittest.TestCase``
    instance, and an attribute holding a truthy value can not be overwritten
    with a different value.
    """

    def _init_sub_example(self) -> None:
        """
        Create the aggregator used by sub_example() and register an after
        hook that re-raises whatever it collected.
        """
        self._sub_examples_agg_ex = AggregatedExceptions()

        def real_assert_sub_examples(context_data: "_ContextData") -> None:
            collected = context_data._sub_examples_agg_ex
            if collected.exceptions:
                collected.raise_correct_exception()

        # Async examples only accept coroutine-function hooks and sync ones
        # only plain functions, so pick the matching flavour.
        if self._example.is_async:

            async def assert_sub_examples(context_data: "_ContextData") -> None:
                real_assert_sub_examples(context_data)

        else:

            def assert_sub_examples(context_data: "_ContextData") -> None:  # type: ignore
                real_assert_sub_examples(context_data)

        self.after(assert_sub_examples)

    def _init_mocks(self) -> None:
        """
        Expose the mocking helpers on ``self`` and collect mock_callable
        assertions as after hooks of this example.
        """
        self.mock_callable = testslide.mock_callable.mock_callable
        self.mock_async_callable = testslide.mock_callable.mock_async_callable
        self.mock_constructor = testslide.mock_constructor.mock_constructor
        self.patch_attribute = testslide.patch_attribute.patch_attribute
        self._mock_callable_after_functions: List[Callable] = []

        def register_assertion(assertion: Callable) -> None:
            # Wrap the zero-argument assertion as a hook that takes the
            # context data, async or not to match the example.
            if self._example.is_async:

                async def f(_: _ContextData) -> None:
                    assertion()

            else:
                f = lambda _: assertion()  # noqa: E731

            self._mock_callable_after_functions.append(f)

        # mock_callable reports its assertions through this module-level hook.
        testslide.mock_callable.register_assertion = register_assertion

    def __init__(self, example: "Example", formatter: "BaseFormatter") -> None:
        # The _init_* helpers read self._example and append to
        # self._after_functions, so these assignments must come first.
        self._example = example
        self._formatter = formatter
        self._context = example.context
        self._after_functions: List[Callable] = []
        self._test_case = unittest.TestCase()
        self._init_sub_example()
        self._init_mocks()

    @staticmethod
    def _not_callable(self: "_ContextData") -> None:
        """Placeholder returned by after(); calling it is always an error."""
        raise BaseException("This function should not be called outside test code.")

    @property
    def _all_methods(self) -> Dict[str, Callable]:
        """Functions of the example's context, including inherited ones."""
        return self._context.all_context_data_methods

    @property
    def _all_memoizable_attributes(self) -> Dict[str, Callable]:
        """Memoized attribute factories of the context, including inherited ones."""
        return self._context.all_context_data_memoizable_attributes

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Refuse to replace an attribute that already holds a truthy value with
        a different value. Falsy values (None, 0, empty containers, ...) can
        still be replaced.
        """
        previous = self.__dict__.get(name)
        if previous and previous != value:
            raise AttributeError(
                f"Attribute {repr(name)} can not be reset.\n"
                "Resetting attribute values is not permitted as it can create "
                "confusion and taint test signal.\n"
                "You can use memoize/memoize_before instead, as they allow "
                "attributes from parent contexs to be overridden consistently "
                "by sub-contexts.\n"
                "Details and examples at the documentation: "
                "https://testslide.readthedocs.io/en/main/testslide_dsl/context_attributes_and_functions/index.html"
            )
        super().__setattr__(name, value)

    def __getattr__(self, name: str) -> Any:
        """
        Resolve names that normal attribute lookup did not find.

        Context functions are bound to this instance and cached. Memoized
        attributes are computed once and cached, and they take precedence
        over a function of the same name. ``assert*`` names are forwarded to
        the unittest.TestCase instance. Anything else raises AttributeError.
        """
        if name in self._all_methods:

            def static(*args: Any, **kwargs: Any) -> Any:
                return self._all_methods[name](self, *args, **kwargs)

            self.__dict__[name] = static

        memoizable = self._all_memoizable_attributes
        if name in memoizable:
            attribute_code = memoizable[name]
            # The factory is called synchronously below (never awaited), so
            # coroutine functions are rejected for async examples.
            if self._example.is_async and inspect.iscoroutinefunction(attribute_code):
                raise ValueError(
                    f"Function can not be a coroutine function: {repr(attribute_code)}"
                )
            self._formatter.dsl_memoize(self._example, attribute_code)
            self.__dict__[name] = attribute_code(self)

        try:
            return self.__dict__[name]
        except KeyError:
            # Forward assert* methods to unittest.TestCase
            if name.startswith("assert") and hasattr(self._test_case, name):
                return getattr(self._test_case, name)
            raise AttributeError(
                f"Context '{self._context}' has no attribute '{name}'"
            )

    def after(self, after_code: Callable) -> Callable:
        """
        Decorator registering ``after_code`` to run after the example code.

        Returns a placeholder that raises if called, so the decorated name
        can not be invoked directly by mistake.
        """
        self._after_functions.append(after_code)
        return self._not_callable

    @contextmanager
    def sub_example(self, name: Optional[str] = None) -> Iterator[None]:
        """
        Context manager usable many times inside the same example. Exceptions
        raised in its body are recorded instead of propagated, and reported
        individually once the example has run. ``name`` is currently unused.
        """
        collector = self._sub_examples_agg_ex
        with collector.catch():
            yield

    def async_run_with_health_checks(self, coro):
        """
        Run ``coro`` in a new event loop through asyncio_run() and return
        what it returns. LeftOverActiveTasks is raised if ``coro`` leaves
        tasks running.
        """
        return asyncio_run(_async_ensure_no_leaked_tasks(coro))
class AggregatedExceptions(Exception):
    """
    Aggregate example execution exceptions.

    Exceptions are collected with append_exception() or the catch() context
    manager. raise_correct_exception() then re-raises them as one outcome.
    """

    def __init__(self) -> None:
        super(AggregatedExceptions, self).__init__()
        self.exceptions: List[BaseException] = []

    def append_exception(self, exception: BaseException) -> None:
        """
        Record ``exception``. Another AggregatedExceptions is flattened: its
        collected exceptions are added one by one instead of nesting it.
        """
        if isinstance(exception, AggregatedExceptions):
            self.exceptions.extend(exception.exceptions)
        else:
            self.exceptions.append(exception)

    @contextmanager
    def catch(self) -> Iterator[None]:
        """
        Context manager that records any exception raised in its body instead
        of propagating it. BaseException subclasses such as
        KeyboardInterrupt are recorded too.
        """
        try:
            yield
        except BaseException as exception:
            self.append_exception(exception)

    def __str__(self) -> str:
        return "{} failures.\n".format(len(self.exceptions)) + "\n".join(
            f"{type(e)}: {str(e)}" for e in self.exceptions
        )

    def raise_correct_exception(self) -> None:
        """
        Re-raise the collected exceptions as a single outcome.

        - Nothing collected: return without raising.
        - Any Skip or unittest.SkipTest (subclasses included) collected: raise
          Skip, so the whole example is reported as skipped.
        - Exactly one exception collected: re-raise it unchanged.
        - Otherwise: raise this aggregate.
        """
        if not self.exceptions:
            return
        # isinstance() so that subclasses count as skips, as they do for
        # unittest itself.
        if any(isinstance(ex, (Skip, unittest.SkipTest)) for ex in self.exceptions):
            raise Skip()
        if len(self.exceptions) == 1:
            raise self.exceptions[0]
        raise self
class Skip(Exception):
    """
    Signals that an example was skipped instead of being run.
    """
class UnexpectedSuccess(Exception):
    """
    Signals that an example expected to fail succeeded instead.
    """
class SlowCallback(Exception):
    """
    Signals that asyncio warned about a slow callback, one that blocked the
    event loop, while an async example was running.
    """
class _ExampleRunner:
    """
    Runs a single Example: its around/before/after hooks and its code.

    Each step is reported to the formatter (dsl_around, dsl_before,
    dsl_memoize_before, dsl_example, dsl_after). mock_callable,
    mock_constructor and patch_attribute patches are always undone when
    run() finishes.
    """

    def __init__(self, example: "Example", formatter: "BaseFormatter") -> None:
        self.example = example
        self.formatter = formatter
        self.trim_path_prefix = self.formatter.trim_path_prefix

    @staticmethod
    async def _fail_if_not_coroutine_function(
        func: Callable, *args: Any, **kwargs: Any
    ) -> Any:
        """
        Await ``func(*args, **kwargs)`` and return its result.

        :raises ValueError: if ``func`` is not a coroutine function. Every
            hook of an async example must be async as well.
        """
        if not inspect.iscoroutinefunction(func):
            raise ValueError(f"Function must be a coroutine function: {repr(func)}")
        return await func(*args, **kwargs)

    async def _real_async_run_all_hooks_and_example(
        self,
        context_data: _ContextData,
        around_functions: Optional[List[Callable]] = None,
    ) -> None:
        """
        Run an async example together with all of its hooks.

        Around hooks are applied recursively, outermost first. Each level pops
        one around hook and passes it a wrapper that recurses into the next
        level. When no around hooks remain:

        1. before hooks run in order; the first failure skips the remaining
           before hooks and the example code;
        2. the example code runs, checked for leaked asyncio tasks;
        3. after hooks run in reverse of [mock_callable assertions, context
           after hooks (parents first), hooks registered via self.after].
           Each has its own catch(), so every after hook runs.

        All failures are then re-raised through
        AggregatedExceptions.raise_correct_exception().

        :param context_data: the object passed to hooks and example as ``self``.
        :param around_functions: remaining around hooks, innermost first. It is
            consumed with ``pop()``. ``None`` on the initial call.
        :raises RuntimeError: if an around hook never calls its wrapper.

        ***********************************************************************
        ***********************************************************************
                                        WARNING
        ***********************************************************************
        ***********************************************************************

        This function **MUST** keep the exact same execution flow as
        _sync_run_all_hooks_and_example()!!!
        """
        if around_functions is None:
            # Reversed so that pop() below yields the outermost hook first.
            around_functions = list(reversed(self.example.context.all_around_functions))

        if not around_functions:
            aggregated_exceptions = AggregatedExceptions()
            with aggregated_exceptions.catch():
                for before_code in self.example.context.all_before_functions:
                    # memoize_before hooks carry the original memoized code,
                    # which is what gets reported to the formatter.
                    if hasattr(before_code, "_memoize_before_code"):
                        self.formatter.dsl_memoize_before(
                            self.example, before_code._memoize_before_code
                        )
                    else:
                        self.formatter.dsl_before(self.example, before_code)
                    await self._fail_if_not_coroutine_function(
                        before_code, context_data
                    )
                self.formatter.dsl_example(self.example, self.example.code)
                # Only the example code itself is checked for leaked tasks.
                await _async_ensure_no_leaked_tasks(
                    self._fail_if_not_coroutine_function(
                        self.example.code, context_data
                    )
                )
            after_functions: List[Callable] = []
            after_functions.extend(context_data._mock_callable_after_functions)
            after_functions.extend(self.example.context.all_after_functions)
            after_functions.extend(context_data._after_functions)
            for after_code in reversed(after_functions):
                # One catch() per hook: a failing after hook does not prevent
                # the remaining ones from running.
                with aggregated_exceptions.catch():
                    self.formatter.dsl_after(self.example, after_code)
                    await self._fail_if_not_coroutine_function(after_code, context_data)
            aggregated_exceptions.raise_correct_exception()
            return
        around_code = around_functions.pop()

        # Records whether the around hook invoked its wrapper.
        wrapped_called: List[bool] = []

        async def async_wrapped() -> None:
            wrapped_called.append(True)
            await self._real_async_run_all_hooks_and_example(
                context_data, around_functions
            )

        self.formatter.dsl_around(self.example, around_code)
        await self._fail_if_not_coroutine_function(
            around_code, context_data, async_wrapped
        )
        if not wrapped_called:
            raise RuntimeError(
                "Around hook "
                + repr(around_code.__name__)
                + " did not execute example code!"
            )

    @contextlib.contextmanager
    def _raise_if_asyncio_warnings(self, context_data: _ContextData) -> Iterator[None]:
        """
        Context manager that turns selected asyncio diagnostics emitted while
        its body runs into test failures:

        - RuntimeWarning messages matching "coroutine '...' was never awaited";
        - asyncio logger warnings "Executing ... took ... seconds" (slow
          callbacks), converted to SlowCallback with a tip appended.

        Other warnings and log messages are passed to the original handlers,
        which are restored on exit. An exception from the body and all
        captured failures are aggregated and raised together.

        Does nothing on Python < 3.7. ``context_data`` is currently unused.
        """
        if sys.version_info < (3, 7):
            yield
            return
        original_showwarning = warnings.showwarning
        caught_failures: List[Union[Exception, str]] = []

        def showwarning(
            message: str,
            category: Type[Warning],
            filename: str,
            lineno: int,
            file: Optional[TextIO] = None,
            line: Optional[str] = None,
        ) -> None:
            # Maps a warning class to the message pattern that counts as a
            # test failure.
            failure_warning_messages: Dict[Any, str] = {
                RuntimeWarning: "^coroutine '.+' was never awaited"
            }
            # NOTE(review): ``message`` is expected to be the Warning instance
            # here, so its class selects the pattern.
            warning_class = type(message)
            pattern = failure_warning_messages.get(warning_class, None)
            if pattern and re.compile(pattern).match(str(message)):
                caught_failures.append(message)
            else:
                original_showwarning(message, category, filename, lineno, file, line)

        warnings.showwarning = showwarning  # type: ignore

        original_logger_warning = asyncio.log.logger.warning

        def logger_warning(msg: str, *args: Any, **kwargs: Any) -> None:
            # Matched against the unformatted message, before ``% args``.
            if re.compile("^Executing .+ took .+ seconds$").match(str(msg)):
                msg = (
                    f"{msg}\n"
                    "During the execution of the async test a slow callback "
                    "that blocked the event loop was detected.\n"
                    "Tip: you can customize the detection threshold with:\n"
                    " asyncio.get_running_loop().slow_callback_duration = seconds"
                )
                caught_failures.append(SlowCallback(msg % args))
            else:
                original_logger_warning(msg, *args, **kwargs)

        asyncio.log.logger.warning = logger_warning  # type: ignore

        aggregated_exceptions = AggregatedExceptions()
        try:
            with aggregated_exceptions.catch():
                yield
        finally:
            warnings.showwarning = original_showwarning
            asyncio.log.logger.warning = original_logger_warning  # type: ignore
            for failure in caught_failures:
                with aggregated_exceptions.catch():
                    raise failure  # type: ignore
            aggregated_exceptions.raise_correct_exception()

    def _async_run_all_hooks_and_example(self, context_data: _ContextData) -> None:
        """
        Run an async example and its hooks with asyncio_run(), turning the
        asyncio warnings handled by _raise_if_asyncio_warnings() into failures.
        """
        # The coroutine object is only created here; it starts executing
        # inside asyncio_run(), after the warning hooks are installed.
        coro = self._real_async_run_all_hooks_and_example(context_data)
        with self._raise_if_asyncio_warnings(context_data):
            asyncio_run(coro)

    @staticmethod
    def _fail_if_coroutine_function(
        func: Callable, *args: Any, **kwargs: Any
    ) -> Optional[Any]:
        """
        Call ``func(*args, **kwargs)`` and return its result.

        :raises ValueError: if ``func`` is a coroutine function. Hooks of a
            sync example must be sync as well.
        """
        if inspect.iscoroutinefunction(func):
            raise ValueError(f"Function can not be a coroutine function: {repr(func)}")
        return func(*args, **kwargs)

    def _sync_run_all_hooks_and_example(
        self,
        context_data: _ContextData,
        around_functions: Optional[List[Callable]] = None,
    ) -> None:
        """
        Run a sync example together with all of its hooks.

        Same flow as the async variant: around hooks outermost first, each
        wrapping the next level; then before hooks (the first failure skips
        the rest and the example), the example code, and every after hook in
        reverse registration order, each in its own catch(). Failures are
        re-raised through AggregatedExceptions.raise_correct_exception().

        :param context_data: the object passed to hooks and example as ``self``.
        :param around_functions: remaining around hooks, innermost first. It is
            consumed with ``pop()``. ``None`` on the initial call.
        :raises RuntimeError: if an around hook never calls its wrapper.

        ***********************************************************************
        ***********************************************************************
                                        WARNING
        ***********************************************************************
        ***********************************************************************

        This function **MUST** keep the exact same execution flow as
        _real_async_run_all_hooks_and_example()!!!
        """
        if around_functions is None:
            # Reversed so that pop() below yields the outermost hook first.
            around_functions = list(reversed(self.example.context.all_around_functions))

        if not around_functions:
            aggregated_exceptions = AggregatedExceptions()
            with aggregated_exceptions.catch():
                for before_code in self.example.context.all_before_functions:
                    # memoize_before hooks carry the original memoized code,
                    # which is what gets reported to the formatter.
                    if hasattr(before_code, "_memoize_before_code"):
                        self.formatter.dsl_memoize_before(
                            self.example, before_code._memoize_before_code
                        )
                    else:
                        self.formatter.dsl_before(self.example, before_code)
                    self._fail_if_coroutine_function(before_code, context_data)
                self.formatter.dsl_example(self.example, self.example.code)
                self._fail_if_coroutine_function(self.example.code, context_data)
            after_functions: List[Callable] = []
            after_functions.extend(context_data._mock_callable_after_functions)
            after_functions.extend(self.example.context.all_after_functions)
            after_functions.extend(context_data._after_functions)
            for after_code in reversed(after_functions):
                # One catch() per hook: a failing after hook does not prevent
                # the remaining ones from running.
                with aggregated_exceptions.catch():
                    self.formatter.dsl_after(self.example, after_code)
                    self._fail_if_coroutine_function(after_code, context_data)
            aggregated_exceptions.raise_correct_exception()
            return
        around_code = around_functions.pop()

        # Records whether the around hook invoked its wrapper.
        wrapped_called: List[bool] = []

        def wrapped() -> None:
            wrapped_called.append(True)
            self._sync_run_all_hooks_and_example(context_data, around_functions)

        self.formatter.dsl_around(self.example, around_code)
        self._fail_if_coroutine_function(around_code, context_data, wrapped)
        if not wrapped_called:
            raise RuntimeError(
                "Around hook "
                + repr(around_code.__name__)
                + " did not execute example code!"
            )

    def run(self) -> None:
        """
        Run the example.

        Raises Skip if the example (or its context) is marked as skipped.
        Otherwise the sync or async path is chosen from ``example.is_async``.
        stdout/stderr are always flushed, and all mock_callable,
        mock_constructor and patch_attribute patches are undone afterwards.
        """
        try:
            if self.example.skip:
                raise Skip()
            context_data = _ContextData(self.example, self.formatter)
            if self.example.is_async:
                self._async_run_all_hooks_and_example(context_data)
            else:
                self._sync_run_all_hooks_and_example(context_data)
        finally:
            sys.stdout.flush()
            sys.stderr.flush()
            testslide.mock_callable.unpatch_all_callable_mocks()
            testslide.mock_constructor.unpatch_all_constructor_mocks()
            testslide.patch_attribute.unpatch_all_mocked_attributes()
class Example:
    """
    A single named example (test case) that belongs to a Context.
    """

    def __init__(
        self,
        name: str,
        code: Callable,
        context: "Context",
        skip: bool = False,
        focus: bool = False,
    ) -> None:
        self.name = name
        self.code = code
        self.is_async = inspect.iscoroutinefunction(code)
        self.context = context
        # ``skip`` and ``focus`` are read-only properties that also look at
        # the context, so this example's own flags live in __dict__.
        vars(self).update(skip=skip, focus=focus)

    @property
    def full_name(self) -> str:
        """The context's full name followed by the example name."""
        return f"{self.context.full_name}: {self.name}"

    @property
    def skip(self) -> bool:
        """
        True if this example or its context is marked to be skipped.
        """
        flags = (self.context.skip, self.__dict__["skip"])
        return any(flags)

    @property
    def focus(self) -> bool:
        """
        True if this example or its context is marked to be focused.
        """
        flags = (self.context.focus, self.__dict__["focus"])
        return any(flags)

    def __str__(self) -> str:
        return self.name
class _TestSlideTestResult(unittest.TestResult):
    """
    Concrete unittest.TestResult to allow unittest.TestCase integration, by
    aggregating failures at an AggregatedExceptions instance.

    Each hook first delegates to unittest.TestResult for its usual
    bookkeeping. Errors, failures and failed subtests are then recorded
    with their original exception. Skips and unexpected successes are
    recorded as Skip() and UnexpectedSuccess() instances.
    """

    def __init__(self) -> None:
        super(_TestSlideTestResult, self).__init__()
        self.aggregated_exceptions = AggregatedExceptions()

    def _add_exception(
        self,
        err: Tuple[
            Type[BaseException],
            BaseException,
            Optional[types.TracebackType],
        ],
    ) -> None:
        """Record the exception instance of a sys.exc_info()-style tuple."""
        _exc_type, exc_value, _exc_traceback = err
        self.aggregated_exceptions.append_exception(exc_value)

    def addError(  # type:ignore
        self,
        test: "TestCase",
        err: Tuple[
            Type[BaseException],
            BaseException,
            types.TracebackType,
        ],
    ) -> None:
        """Called when an error has occurred. 'err' is a tuple of values as
        returned by sys.exc_info().
        """
        super(_TestSlideTestResult, self).addError(test, err)  # type: ignore
        self._add_exception(err)

    def addFailure(  # type:ignore
        self,
        test: "TestCase",
        err: Tuple[
            Type[BaseException],
            BaseException,
            types.TracebackType,
        ],
    ) -> None:
        """Called when a failure (a failed assertion) has occurred. 'err' is a
        tuple of values as returned by sys.exc_info()."""
        super(_TestSlideTestResult, self).addFailure(test, err)
        self._add_exception(err)

    def addSkip(self, test: "TestCase", reason: str) -> None:  # type: ignore
        """Called when the test case test is skipped. reason is the reason
        the test gave for skipping."""
        super(_TestSlideTestResult, self).addSkip(test, reason)
        # exc_info-style tuple: the first item is the exception class itself.
        self._add_exception((Skip, Skip(), None))

    def addUnexpectedSuccess(self, test: "TestCase") -> None:  # type: ignore
        """Called when the test case test was marked with the expectedFailure()
        decorator, but succeeded."""
        super(_TestSlideTestResult, self).addUnexpectedSuccess(test)
        self._add_exception((UnexpectedSuccess, UnexpectedSuccess(), None))

    def addSubTest(  # type: ignore
        self,
        test: "TestCase",
        subtest: "TestCase",
        err: Optional[
            Tuple[
                Type[BaseException],
                BaseException,
                Optional[types.TracebackType],
            ]
        ],
    ) -> None:
        """Called at the end of a subtest.
        'err' is None if the subtest ended successfully, otherwise it's a
        tuple of values as returned by sys.exc_info().
        """
        super(_TestSlideTestResult, self).addSubTest(test, subtest, err)  # type: ignore
        if err is not None:
            self._add_exception(err)
class Context:
    """
    Container for example contexts.

    A context holds examples, child contexts, shared contexts,
    before/after/around hooks, functions and memoized attributes. The
    ``all_*`` properties combine a context's own entries with those of its
    parent contexts.
    """

    _SAME_CONTEXT_NAME_ERROR = "A context with the same name is already defined"

    # List of all top level contexts created
    all_top_level_contexts: List["Context"] = []

    # Constructor
    def __init__(
        self,
        name: str,
        parent_context: Optional["Context"] = None,
        shared: bool = False,
        skip: bool = False,
        focus: bool = False,
    ) -> None:
        """
        Creates a new context.

        Non-shared contexts without a parent are registered in
        ``all_top_level_contexts``.

        :raises RuntimeError: if a sibling context already uses ``name``.
            Siblings are the parent's children, or the top-level contexts
            when there is no parent.
        """
        # Validate context name
        if parent_context:
            current_level_contexts = parent_context.children_contexts
        else:
            current_level_contexts = self.all_top_level_contexts
        if name in [context.name for context in current_level_contexts]:
            raise RuntimeError(self._SAME_CONTEXT_NAME_ERROR)

        self.name: str = name
        self.parent_context = parent_context
        self.shared = shared
        # ``skip`` and ``focus`` are read-only properties that also consider
        # parent contexts, so this context's own flags live in __dict__.
        self.__dict__["skip"] = skip
        self.__dict__["focus"] = focus
        self.children_contexts: List["Context"] = []
        self.examples: List["Example"] = []
        self.before_functions: List[Callable] = []
        self.after_functions: List[Callable] = []
        self.around_functions: List[Callable] = []
        self.context_data_methods: Dict[str, Callable] = {}
        self.context_data_memoizable_attributes: Dict[str, Callable] = {}
        self.shared_contexts: Dict[str, "Context"] = {}

        if not self.parent_context and not self.shared:
            self.all_top_level_contexts.append(self)

    # Properties

    @property
    def parent_contexts(self) -> List["Context"]:
        """
        Returns a list of all parent contexts, from bottom to top.
        """
        final_list = []
        parent = self.parent_context
        while parent:
            final_list.append(parent)
            parent = parent.parent_context
        return final_list

    @property
    def depth(self) -> int:
        """
        Number of parent contexts this context has.
        """
        return len(self.parent_contexts)

    def _all_parents_as_dict(original: type) -> Callable[["Context"], Dict[str, Any]]:  # type: ignore # noqa: B902
        """
        Use as a decorator for empty functions named all_attribute_name, to make
        them return a dict with self.parent_context.all_attribute_name and
        self.attribute_name. This context's entries override inherited ones
        with the same key.
        """
        # e.g. "all_context_data_methods" -> "context_data_methods"
        own_attribute_name = original.__name__[len("all_") :]

        def get_all(self: "Context") -> Dict[str, Any]:
            final_dict: Dict[str, Any] = {}
            if self.parent_context:
                final_dict.update(getattr(self.parent_context, original.__name__))
            final_dict.update(getattr(self, own_attribute_name))
            return final_dict

        return get_all

    def _all_parents_as_list(original: type) -> Callable[["Context"], List[Any]]:  # type: ignore # noqa: B902
        """
        Use as a decorator for empty functions named all_attribute_name, to make
        them return a list with self.parent_context.all_attribute_name followed
        by self.attribute_name (outermost parent first).
        """
        # e.g. "all_before_functions" -> "before_functions"
        own_attribute_name = original.__name__[len("all_") :]

        def get_all(self: "Context") -> List[Any]:
            final_list: List[Any] = []
            if self.parent_context:
                final_list.extend(getattr(self.parent_context, original.__name__))
            final_list.extend(getattr(self, own_attribute_name))
            return final_list

        return get_all

    @property  # type: ignore
    @_all_parents_as_dict
    def all_context_data_methods(self) -> None:
        """
        Returns a combined dict of all context_data_methods, including from
        parent contexts.
        """
        pass

    @property  # type: ignore
    @_all_parents_as_dict
    def all_context_data_memoizable_attributes(self) -> None:
        """
        Returns a combined dict of all context_data_memoizable_attributes,
        including from parent contexts.
        """
        pass

    @property  # type: ignore
    @_all_parents_as_list
    def all_around_functions(self) -> None:
        """
        Return a list of all around_functions, including from parent contexts.
        """
        pass

    @property  # type: ignore
    @_all_parents_as_list
    def all_before_functions(self) -> None:
        """
        Return a list of all before_functions, including from parent contexts.
        """
        pass

    @property  # type: ignore
    @_all_parents_as_list
    def all_after_functions(self) -> None:
        """
        Return a list of all after_functions, including from parent contexts.
        """
        pass

    @property  # type: ignore
    @_all_parents_as_dict
    def all_shared_contexts(self) -> None:
        """
        Returns a combined dict of all shared_contexts, including from parent
        contexts.
        """
        pass

    @property
    def all_examples(self) -> List["Example"]:
        """
        List of all examples in this context and nested contexts (depth-first,
        this context's own examples first).
        """
        final_list = []
        final_list.extend(self.examples)
        for child_context in self.children_contexts:
            final_list.extend(child_context.all_examples)
        return final_list

    @property
    def hierarchy(self) -> List["Context"]:
        """
        Returns a list of all contexts in this hierarchy, from the top-level
        context down to this one.
        """
        return list(reversed(self.parent_contexts)) + [self]

    @property
    def full_name(self) -> str:
        """
        Full context name, including parent contexts.
        """
        return ", ".join(str(context) for context in self.hierarchy)

    @property
    def skip(self) -> bool:
        """
        True if this context or any parent context is tagged to be skipped.
        """
        return any(context.__dict__["skip"] for context in self.hierarchy)

    @property
    def focus(self) -> bool:
        """
        True if this context or any parent context is tagged to be focused.
        """
        return any(context.__dict__["focus"] for context in self.hierarchy)

    def __str__(self) -> str:
        return self.name

    def add_child_context(
        self, name: str, skip: bool = False, focus: bool = False
    ) -> "Context":
        """
        Creates a nested context below self.

        :raises RuntimeError: if a child context with ``name`` already exists.
        """
        if name in [context.name for context in self.children_contexts]:
            raise RuntimeError(self._SAME_CONTEXT_NAME_ERROR)
        child_context = Context(name, parent_context=self, skip=skip, focus=focus)
        self.children_contexts.append(child_context)
        return child_context

    def add_example(
        self, name: str, example_code: Callable, skip: bool = False, focus: bool = False
    ) -> "Example":
        """
        Add an example to this context and return it.

        :raises RuntimeError: if this context already has an example ``name``.
        """
        if name in [example.name for example in self.examples]:
            raise RuntimeError(
                f"An example with the same name '{name}' is already defined"
            )

        self.examples.append(
            Example(name, code=example_code, context=self, skip=skip, focus=focus)
        )

        return self.examples[-1]

    def has_attribute(self, name: str) -> bool:
        """
        True if this context itself (not its parents) defines a function or
        memoized attribute called ``name``.
        """
        return any(
            [
                name in self.context_data_methods.keys(),
                name in self.context_data_memoizable_attributes.keys(),
            ]
        )

    def add_function(self, name: str, function_code: Callable) -> None:
        """
        Add given function to example execution scope.

        :raises AttributeError: if ``name`` is already defined in this context.
        """
        if self.has_attribute(name):
            raise AttributeError(
                'Attribute "{}" already set for context "{}"'.format(name, self)
            )
        self.context_data_methods[name] = function_code

    def add_memoized_attribute(
        self, name: str, memoizable_code: Callable, before: bool = False
    ) -> None:
        """
        Add given attribute name to execution scope, by lazily memoizing the return
        value of memoizable_code().

        With ``before=True`` the value is instead computed eagerly by a before
        hook. That hook uses whichever implementation of ``name`` is visible
        from the running example's context, so sub-contexts can override it.

        :raises AttributeError: if ``name`` is already defined in this context.
        """
        if self.has_attribute(name):
            raise AttributeError(
                'Attribute "{}" already set for context "{}"'.format(name, self)
            )
        self.context_data_memoizable_attributes[name] = memoizable_code

        if before:
            if inspect.iscoroutinefunction(memoizable_code):

                async def async_materialize_attribute(
                    context_data: "_ContextData",
                ) -> None:
                    code = context_data._context.all_context_data_memoizable_attributes[
                        name
                    ]
                    context_data.__dict__[name] = await code(context_data)

                # Lets the runner report the memoized code to the formatter.
                async_materialize_attribute._memoize_before_code = memoizable_code  # type: ignore
                self.before_functions.append(async_materialize_attribute)
            else:

                def materialize_attribute(context_data: "_ContextData") -> None:
                    code = context_data._context.all_context_data_memoizable_attributes[
                        name
                    ]
                    context_data.__dict__[name] = code(context_data)

                # Lets the runner report the memoized code to the formatter.
                materialize_attribute._memoize_before_code = memoizable_code  # type: ignore
                self.before_functions.append(materialize_attribute)

    def add_shared_context(self, name: str, shared_context_code: "Context") -> None:
        """
        Create a shared context.

        :raises RuntimeError: if a shared context ``name`` already exists here.
        """
        if name in self.shared_contexts:
            raise RuntimeError("A shared context with the same name is already defined")
        self.shared_contexts[name] = shared_context_code

    def add_test_case(self, test_case: Type["TestCase"], attr_name: str) -> None:
        """
        Add around hooks to context from given unittest.TestCase class. Only
        hooks such as setUp or tearDown will be called, no tests will be
        included.

        While the example runs, its TestCase instance is available on the
        example's ``self`` as ``attr_name``. Errors, failures, skips and
        unexpected successes reported by unittest are re-raised through
        AggregatedExceptions.raise_correct_exception().
        """

        def wrap_test_case(self: "_ContextData", example: Callable) -> None:
            def test_test_slide(_: Any) -> None:
                example()

            def exec_body(ns: Dict[str, Callable]) -> None:
                ns.update({"test_test_slide": test_test_slide})

            # Build a child class of given TestCase, with a defined test that
            # will run TestSlide example.
            test_slide_test_case = types.new_class(
                "TestSlideTestCase", bases=(test_case,), exec_body=exec_body
            )

            # This suite will only contain TestSlide's example test.
            test_suite = unittest.TestLoader().loadTestsFromName(
                "test_test_slide", test_slide_test_case  # type: ignore
            )
            setattr(self, attr_name, list(test_suite)[0])
            result = _TestSlideTestResult()
            test_suite(result=result)  # type: ignore
            # Deliberately not gated on result.wasSuccessful(): unittest counts
            # skipped tests as successful, which would drop skips, e.g. from
            # setUp() or @unittest.skip. The call is a no-op when nothing was
            # recorded.
            result.aggregated_exceptions.raise_correct_exception()

        self.around_functions.append(wrap_test_case)
def reset() -> None:
    """
    Forget all defined contexts (and therefore their hooks) by emptying the
    class-level Context.all_top_level_contexts registry in place.
    """
    del Context.all_top_level_contexts[:]
class TestCase(unittest.TestCase):
    """
    A subclass of unittest.TestCase that adds TestSlide's features.

    mock_callable assertions and the undoing of all mock_callable,
    mock_constructor and patch_attribute patches are registered as cleanups
    of each test.
    """

    def setUp(self) -> None:
        def register_assertion(assertion: Callable) -> None:
            self.addCleanup(assertion)

        testslide.mock_callable.register_assertion = register_assertion
        for unpatch in (
            testslide.mock_callable.unpatch_all_callable_mocks,
            testslide.mock_constructor.unpatch_all_constructor_mocks,
            testslide.patch_attribute.unpatch_all_mocked_attributes,
        ):
            self.addCleanup(unpatch)
        super().setUp()

    @staticmethod
    def mock_callable(
        *args: Any, **kwargs: Any
    ) -> testslide.mock_callable._MockCallableDSL:
        """Delegate to testslide.mock_callable.mock_callable()."""
        return testslide.mock_callable.mock_callable(*args, **kwargs)

    @staticmethod
    def mock_async_callable(
        *args: Any, **kwargs: Any
    ) -> testslide.mock_callable._MockCallableDSL:
        """Delegate to testslide.mock_callable.mock_async_callable()."""
        return testslide.mock_callable.mock_async_callable(*args, **kwargs)

    @staticmethod
    def mock_constructor(
        *args: Any, **kwargs: Any
    ) -> testslide.mock_constructor._MockConstructorDSL:
        """Delegate to testslide.mock_constructor.mock_constructor()."""
        return testslide.mock_constructor.mock_constructor(*args, **kwargs)

    @staticmethod
    def patch_attribute(*args: Any, **kwargs: Any) -> None:
        """Delegate to testslide.patch_attribute.patch_attribute()."""
        return testslide.patch_attribute.patch_attribute(*args, **kwargs)