# Source code for asynctest._fail_on

# coding: utf-8
"""
:class:`asynctest.TestCase` decorator which controls checks performed after
tests.

This module is separated from :mod:`asynctest.case` to avoid circular imports
in modules registering new checks.

To implement new checks:

    * its name must be added in the ``DEFAULTS`` dict,

    * a static method of the same name must be added to the :class:`_fail_on`
      class,

    * an optional static method named ``before_test_[name of the check]`` can
      be added to :class:`_fail_on` to implement some set-up before the test
      runs.

A check may only be available on some platforms, activated by a conditional
import. In this case, ``DEFAULTS`` and :class:`_fail_on` can be updated in the
module. There is an example in the :mod:`asynctest.selector` module.
"""
from asyncio import TimerHandle


# Name of the attribute under which a :class:`_fail_on` checker is stored on a
# decorated object: it is written by ``_fail_on.__call__`` and read back (from
# the test case) by ``_fail_on.get_checks``.
_FAIL_ON_ATTR = "_asynctest_fail_on"


#: Default value of the arguments of @fail_on, the name of the argument matches
#: the name of the static method performing the check in the :class:`_fail_on`.
#: The value is True when the check is enabled by default, False otherwise.
DEFAULTS = {
    "unused_loop": False,
    "active_handles": False,
}


class _fail_on:
    """
    Decorator object holding the set of post-test checks enabled for a test.

    ``checks`` maps the name of a check to a boolean telling whether the check
    is enabled. Each check is implemented by a method of this class bearing
    the same name, and may come with an optional ``before_test_<name>`` hook
    invoked before the test runs.

    Decorating an object stores a checker on it under the attribute named by
    ``_FAIL_ON_ATTR``.
    """

    def __init__(self, checks=None):
        """
        :param checks: mapping of check name to bool; an empty mapping is used
                       when omitted (or falsy).
        """
        self.checks = checks or {}
        # Result of get_checks(), computed lazily and reset by update().
        self._computed_checks = None

    def __call__(self, func):
        """
        Attach the checks to ``func`` and return ``func`` itself.

        If ``func`` already carries a checker, a copy of that checker updated
        with the checks of this instance is stored instead, so the checker
        previously attached is never mutated.
        """
        checker = getattr(func, _FAIL_ON_ATTR, None)
        if checker is not None:
            checker = checker.copy()
            checker.update(self.checks)
        else:
            checker = self.copy()

        setattr(func, _FAIL_ON_ATTR, checker)
        return func

    def update(self, checks, override=True):
        """
        Merge ``checks`` into the checks of this instance.

        :param checks: mapping of check name to bool.
        :param override: when true, values of ``checks`` replace the existing
                         ones; otherwise only the names not configured yet are
                         added.
        """
        if override:
            self.checks.update(checks)
        else:
            for check, value in checks.items():
                self.checks.setdefault(check, value)

        # The configuration changed: drop the cached result so that the next
        # call to get_checks() does not return stale values.
        self._computed_checks = None

    def copy(self):
        """
        Return a new :class:`_fail_on` holding a shallow copy of the checks.
        """
        return _fail_on(self.checks.copy())

    def get_checks(self, case):
        """
        Return the dict of all checks (name -> enabled) which apply to
        ``case``.

        Values are resolved in increasing order of precedence: ``DEFAULTS``,
        then the checks of the checker attached to ``case`` (if any), then the
        checks of this instance.
        """
        # cache the result so it's consistent across calls to get_checks()
        # NOTE: the cache is not keyed on ``case``; the first case seen after
        # the cache was (re)set determines the result.
        if self._computed_checks is None:
            checks = DEFAULTS.copy()

            try:
                checks.update(getattr(case, _FAIL_ON_ATTR, None).checks)
            except AttributeError:
                # no checker (or not a checker) attached to the case
                pass

            checks.update(self.checks)
            self._computed_checks = checks

        return self._computed_checks

    def before_test(self, case):
        """
        Invoke the ``before_test_<name>`` hook of each enabled check, if any.
        """
        checks = self.get_checks(case)
        for check in filter(checks.get, checks):
            # Hooks are optional: most checks don't define one.
            hook = getattr(self, "before_test_" + check, None)
            if hook is None:
                continue

            try:
                hook(case)
            except (AttributeError, TypeError):
                # Best effort, kept for backward compatibility: a hook which
                # can't be called, or which raises one of these exceptions, is
                # silently ignored.
                pass

    def check_test(self, case):
        """
        Run each enabled check against ``case``.

        An enabled check without a matching method raises
        :exc:`AttributeError`.
        """
        checks = self.get_checks(case)
        for check in filter(checks.get, checks):
            getattr(self, check)(case)

    # checks

    @staticmethod
    def unused_loop(case):
        """
        Fail ``case`` if the ``_asynctest_ran`` flag of its loop is false.
        """
        if not case.loop._asynctest_ran:
            case.fail("Loop did not run during the test")

    @staticmethod
    def _is_live_timer_handle(handle):
        """
        Tell if ``handle`` is a :class:`asyncio.TimerHandle` which has not
        been cancelled.
        """
        return isinstance(handle, TimerHandle) and not handle._cancelled

    @classmethod
    def _live_timer_handles(cls, loop):
        """
        Return an iterator over the live timer handles found in
        ``loop._scheduled``.
        """
        return filter(cls._is_live_timer_handle, loop._scheduled)

    @classmethod
    def active_handles(cls, case):
        """
        Fail ``case`` if its loop still holds timer handles which have not
        been cancelled.
        """
        handles = tuple(cls._live_timer_handles(case.loop))
        if handles:
            case.fail("Loop contained unfinished work {!r}".format(handles))


def fail_on(**kwargs):
    """
    Enable checks on the loop state after a test ran to help testers to
    identify common mistakes.

    :param kwargs: name of a check (a key of ``DEFAULTS``) mapped to a value
                   telling whether the check is enabled.
    :raises TypeError: if a keyword argument is not the name of a known check.
    :return: a :class:`_fail_on` object configured with ``kwargs``.
    """
    # documented in asynctest.case.rst
    for kwarg in kwargs:
        if kwarg not in DEFAULTS:
            raise TypeError("fail_on() got an unexpected keyword argument "
                            "'{}'".format(kwarg))

    return _fail_on(kwargs)
def _fail_on_all(flag, func):
    """
    Build a checker where every check listed in ``DEFAULTS`` is set to
    ``flag``.

    Return the checker itself when ``func`` is None, else the result of
    applying the checker to ``func``.
    """
    checker = _fail_on(dict.fromkeys(DEFAULTS, flag))
    return checker if func is None else checker(func)
def strict(func=None):
    """
    Activate strict checking of the state of the loop after a test ran.

    Can be used with or without a decorated object: when ``func`` is None,
    the checker is returned instead of being applied.
    """
    # documented in asynctest.case.rst
    return _fail_on_all(True, func)
def lenient(func=None):
    """
    Deactivate all checks after a test ran.

    Can be used with or without a decorated object: when ``func`` is None,
    the checker is returned instead of being applied.
    """
    # documented in asynctest.case.rst
    return _fail_on_all(False, func)