tutorial/mocking.py¶

# coding: utf-8
# pylama: ignore=C0103, ignore camel case variable name (AsyncClientMock)

import asyncio
import collections
import itertools
import logging

import asynctest


class Client:
    """
    Synchronous client interface.

    Every method raises :exc:`NotImplementedError`: the class only declares
    the methods, it does not implement them.
    """

    def add_user(self, user):
        """Not implemented: always raises :exc:`NotImplementedError`."""
        raise NotImplementedError()

    def get_users(self):
        """Not implemented: always raises :exc:`NotImplementedError`."""
        raise NotImplementedError()

    def increase_nb_users_cached(self, nb_cached):
        """Not implemented: always raises :exc:`NotImplementedError`."""
        raise NotImplementedError()


class AsyncClient:
    """
    Asynchronous client interface.

    The request methods are coroutine functions accepting an optional
    ``transaction`` argument. They all raise :exc:`NotImplementedError`:
    the class only declares the interface.
    """

    async def add_user(self, user, transaction=None):
        """Not implemented: always raises :exc:`NotImplementedError`."""
        raise NotImplementedError

    async def get_users(self, transaction=None):
        """Not implemented: always raises :exc:`NotImplementedError`."""
        raise NotImplementedError

    async def increase_nb_users_cached(self, nb_cached, transaction=None):
        """Not implemented: always raises :exc:`NotImplementedError`."""
        raise NotImplementedError

    def get_users_cursor(self, transaction=None):
        """
        Return a :class:`Cursor` bound to ``transaction``, or to this client
        when no (or a falsy) transaction is given.
        """
        return self.Cursor(transaction or self)

    def new_transaction(self):
        """
        Return a new :class:`Transaction` attached to this client.
        """
        return self.Transaction(self)

    class Transaction:
        """
        Asynchronous context manager forwarding calls to a client with the
        argument ``transaction`` set to itself.

        A method of the client can be invoked either by name,
        ``transaction("get_users")``, or as an attribute,
        ``transaction.get_users()``.
        """

        def __init__(self, client):
            self.client = client

        def __call__(self, funcname, *args, **kwargs):
            """
            Forwards the call to the client, with the argument ``transaction``
            set.

            :param funcname: name of the method of the client to call
            :return: whatever the method of the client returns (an awaitable
                     for the coroutine methods)
            :raises AttributeError: if the client has no such attribute
            """
            method = getattr(self.client, funcname)
            return method(*args, transaction=self, **kwargs)

        def __getattr__(self, funcname):
            """
            Expose the attributes of the client as methods of the transaction.

            Only invoked when the regular attribute lookup fails. The returned
            callable forwards to :meth:`__call__`, hence passes
            ``transaction=self`` to the client.

            :raises AttributeError: for private or special names, for
                                    ``client`` itself, and for names missing
                                    on the client
            """
            # "client" is guarded to avoid infinite recursion when the
            # instance attribute has not been set yet; private and special
            # names are never forwarded.
            if funcname == "client" or funcname.startswith("_"):
                raise AttributeError(funcname)

            # Fail now (with AttributeError) rather than at call time, so that
            # hasattr() gives a truthful answer.
            getattr(self.client, funcname)

            def forward(*args, **kwargs):
                return self(funcname, *args, **kwargs)

            return forward

        async def __aenter__(self):
            return self

        async def __aexit__(self, *args):
            # Nothing to release, and exceptions are not suppressed.
            pass

    class Cursor:
        """
        Asynchronous iterator bound to a transaction (or a client).

        Iteration is not implemented: ``__anext__`` raises
        :exc:`NotImplementedError`.
        """

        def __init__(self, transaction):
            self.transaction = transaction

        def __aiter__(self):
            return self

        async def __anext__(self):
            # if the request has not been started, do it there
            raise NotImplementedError


def cache_users(client, cache):
    """
    Fetch the users known by ``client`` and store in ``cache`` those which
    are not there yet, indexed by their ``id``.

    The number of users newly stored is reported to the client with
    ``increase_nb_users_cached()``, logged at debug level, then returned.

    :param client: object providing ``get_users()`` and
                   ``increase_nb_users_cached()``
    :param cache: a dict-like object
    :return: the number of users added to ``cache`` by this call
    """
    added = 0

    for candidate in client.get_users():
        # Already known: leave the cached entry untouched.
        if candidate.id in cache:
            continue

        added += 1
        cache[candidate.id] = candidate

    client.increase_nb_users_cached(added)
    logging.debug("added %d users to the cache %r", added, cache)

    return added


class StubClient:
    """
    In-memory client: serves the users it was given and keeps the running
    total of the counts passed to ``increase_nb_users_cached()``.
    """

    User = collections.namedtuple("User", "id username")

    def __init__(self, *users_to_return):
        self.nb_users_cached = 0
        # Own list, built from the positional arguments.
        self.users_to_return = list(users_to_return)

    def add_user(self, user):
        """Append ``user`` to the list returned by ``get_users()``."""
        self.users_to_return.append(user)

    def get_users(self):
        """Return the internal list of users (not a copy)."""
        return self.users_to_return

    def increase_nb_users_cached(self, nb_cached):
        """Add ``nb_cached`` to ``nb_users_cached``."""
        self.nb_users_cached = self.nb_users_cached + nb_cached


class TestUsingStub(asynctest.TestCase):
    """Tests of ``cache_users()`` running against a ``StubClient``."""

    def test_one_user_added_to_cache(self):
        admin = StubClient.User(1, "a.dmin")
        stub = StubClient(admin)
        users_by_id = {}

        # First run: the cache is empty, the user must be stored
        first_count = cache_users(stub, users_by_id)

        self.assertEqual(first_count, 1)
        self.assertEqual(users_by_id[1], admin)

        # Second run: the user is already cached, nothing is added
        second_count = cache_users(stub, users_by_id)

        self.assertEqual(second_count, 0)
        self.assertEqual(users_by_id[1], admin)

    def test_no_users_to_add(self):
        users_by_id = {}
        empty_stub = StubClient()

        added = cache_users(empty_stub, users_by_id)

        self.assertEqual(added, 0)
        self.assertEqual(len(users_by_id), 0)


class TestUsingMock(asynctest.TestCase):
    """Test of ``cache_users()`` with a mock built from a ``Client``."""

    def test_no_users_to_add(self):
        users_by_id = {}
        mocked_client = asynctest.Mock(Client())
        # The mocked server has no user to offer
        mocked_client.get_users.return_value = []

        added = cache_users(mocked_client, users_by_id)

        mocked_client.get_users.assert_called()
        self.assertEqual(added, 0)
        self.assertEqual(len(users_by_id), 0)

        # The count (zero) must have been reported exactly once
        mocked_client.increase_nb_users_cached.assert_called_once_with(0)


async def cache_users_async(client, cache):
    """
    Coroutine version of the user caching: awaits ``client.get_users()``,
    stores in ``cache`` the users which are not there yet (indexed by their
    ``id``), then awaits ``client.increase_nb_users_cached()`` with the
    number of users added.

    :param client: object whose ``get_users()`` and
                   ``increase_nb_users_cached()`` return awaitables
    :param cache: a dict-like object
    :return: the number of users added to ``cache`` by this call
    """
    added = 0

    for candidate in await client.get_users():
        # Already known: leave the cached entry untouched.
        if candidate.id in cache:
            continue

        added += 1
        cache[candidate.id] = candidate

    await client.increase_nb_users_cached(added)
    logging.debug("added %d users to the cache %r", added, cache)

    return added


class TestUsingFuture(asynctest.TestCase):
    """
    Test of ``cache_users_async()`` where the mocked methods return futures
    which already hold a result.
    """

    async def test_no_users_to_add(self):
        mocked_client = asynctest.Mock(Client())

        # get_users() gives a future resolved with an empty list of users
        users_future = asyncio.Future()
        mocked_client.get_users.return_value = users_future
        users_future.set_result([])

        # increase_nb_users_cached() gives a future resolved with None
        ack_future = asyncio.Future()
        mocked_client.increase_nb_users_cached.return_value = ack_future
        ack_future.set_result(None)

        users_by_id = {}

        added = await cache_users_async(mocked_client, users_by_id)

        mocked_client.get_users.assert_called()
        self.assertEqual(added, 0)
        self.assertEqual(len(users_by_id), 0)

        mocked_client.increase_nb_users_cached.assert_called_once_with(0)


class TestUsingCoroutineMock(asynctest.TestCase):
    """
    Test of ``cache_users_async()`` where the methods of the mocked client
    are replaced by ``CoroutineMock`` objects.
    """

    async def test_no_users_to_add(self):
        users_by_id = {}
        mocked_client = asynctest.Mock(Client())
        mocked_client.get_users = asynctest.CoroutineMock(return_value=[])
        mocked_client.increase_nb_users_cached = asynctest.CoroutineMock()

        added = await cache_users_async(mocked_client, users_by_id)

        mocked_client.get_users.assert_awaited()
        self.assertEqual(added, 0)
        self.assertEqual(len(users_by_id), 0)

        # The count (zero) must have been awaited exactly once
        mocked_client.increase_nb_users_cached.assert_awaited_once_with(0)


class TestUsingCoroutineMockAndSpec(asynctest.TestCase):
    """
    Test of ``cache_users_async()`` with a mock built from an
    ``AsyncClient`` instance, without replacing its methods by hand.
    """

    async def test_no_users_to_add(self):
        users_by_id = {}
        mocked_client = asynctest.Mock(AsyncClient())
        # The mocked server has no user to offer
        mocked_client.get_users.return_value = []

        added = await cache_users_async(mocked_client, users_by_id)

        mocked_client.get_users.assert_awaited()
        self.assertEqual(added, 0)
        self.assertEqual(len(users_by_id), 0)

        mocked_client.increase_nb_users_cached.assert_awaited_once_with(0)


class TestAutoSpec(asynctest.TestCase):
    """Tests showing ``asynctest.create_autospec()`` on a coroutine function
    and on a class."""

    async def test_functions_and_coroutines_arguments_are_checked(self):
        mocked_client = asynctest.Mock(Client())
        users_by_id = {}

        autospecced = asynctest.create_autospec(cache_users_async)

        with self.subTest("create_autospec returns a regular mock"):
            await autospecced(mocked_client, users_by_id)
            autospecced.assert_awaited_once_with(mocked_client, users_by_id)

        with self.subTest("an exception is raised when the mock is called "
                          "with the wrong number of arguments"):
            # cache_users_async takes two arguments, four are given
            with self.assertRaises(TypeError):
                await autospecced("wrong", "number", "of", "args")

    async def test_create_autospec_on_a_class(self):
        AsyncClientMock = asynctest.create_autospec(AsyncClient)
        instance = AsyncClientMock()

        with self.subTest("the mock of a class returns a mock instance of "
                          "the class"):
            self.assertIsInstance(instance, AsyncClient)

        with self.subTest("attributes of the mock instance are correctly "
                          "mocked as coroutines"):
            # Awaiting would fail if the result were not awaitable
            await instance.increase_nb_users_cached(1)


class TestCoroutineMockResult(asynctest.TestCase):
    """
    Tests showing the ways to configure the result of a ``CoroutineMock``
    (``return_value``, ``side_effect``) and of a mock wrapping an object.
    """

    async def test_result_set_with_return_value(self):
        expected = object()
        coroutine_mock = asynctest.CoroutineMock()
        coroutine_mock.return_value = expected

        # the configured object is the result of the call
        self.assertIs(expected, await coroutine_mock())

        # two successive calls give the very same object
        first = await coroutine_mock()
        second = await coroutine_mock()
        self.assertIs(first, second)

    async def test_result_with_side_effect_function(self):
        def to_upper(*words):
            return tuple(word.upper() for word in words)

        coroutine_mock = asynctest.CoroutineMock()
        coroutine_mock.side_effect = to_upper

        first = await coroutine_mock("first", "call")
        self.assertEqual(("FIRST", "CALL"), first)

        second = await coroutine_mock("a", "second", "call")
        self.assertEqual(("A", "SECOND", "CALL"), second)

    async def test_result_with_side_effect_exception(self):
        coroutine_mock = asynctest.CoroutineMock()

        # side_effect is an exception class: an exception of this type is
        # raised
        coroutine_mock.side_effect = NotImplementedError
        with self.assertRaises(NotImplementedError):
            await coroutine_mock("any", "number", "of", "args")

        # side_effect is an exception instance: this exact object is raised
        coroutine_mock.side_effect = Exception("an instance of exception")
        with self.assertRaises(Exception) as caught:
            await coroutine_mock()

        self.assertIs(coroutine_mock.side_effect, caught.exception)

    async def test_result_with_side_effect_iterable(self):
        coroutine_mock = asynctest.CoroutineMock()

        # each call consumes the next value of the list
        coroutine_mock.side_effect = ["one", "two", "three"]
        for expected in ("one", "two", "three"):
            self.assertEqual(expected, await coroutine_mock())

        # an endless iterator can be used as well
        coroutine_mock.side_effect = itertools.cycle(["odd", "even"])
        for expected in ("odd", "even", "odd", "even"):
            self.assertEqual(expected, await coroutine_mock())

    async def test_result_with_wrapped_object(self):
        wrapped = StubClient()
        wrapper = asynctest.Mock(wrapped, wraps=wrapped)
        users_by_id = {}

        wrapped.add_user(StubClient.User(1, "a.dmin"))
        cache_users(wrapper, users_by_id)

        wrapper.get_users.assert_called()
        # the call on the mock gives what the wrapped stub returns
        self.assertEqual(wrapped.users_to_return, wrapper.get_users())


async def cache_users_with_cursor(client, cache):
    """
    Cache the users read from a cursor, inside a transaction.

    Enters the asynchronous context manager returned by
    ``client.new_transaction()``, iterates asynchronously over the result of
    ``get_users_cursor()`` called on the transaction, stores in ``cache`` the
    users which are not there yet (indexed by their ``id``), then awaits
    ``increase_nb_users_cached()`` on the transaction.

    :param client: object providing ``new_transaction()``
    :param cache: a dict-like object
    :return: the number of users added to ``cache`` by this call
    """
    added = 0

    async with client.new_transaction() as txn:
        async for candidate in txn.get_users_cursor():
            # Already known: leave the cached entry untouched.
            if candidate.id in cache:
                continue

            added += 1
            cache[candidate.id] = candidate

        # Reported while the transaction is still open.
        await txn.increase_nb_users_cached(added)

    logging.debug("added %d users to the cache %r", added, cache)

    return added


class TestWithMagicMethods(asynctest.TestCase):
    """
    Tests using ``asynctest.MagicMock`` as an asynchronous context manager
    and as an asynchronous iterable.
    """

    async def test_context_manager(self):
        with self.assertRaises(AssertionError):
            async with asynctest.MagicMock() as entered:
                # entered has never been called: assert_called() must fail
                entered.assert_called()

    async def test_empty_iterable(self):
        seen = 0
        async for _ in asynctest.MagicMock():
            seen += 1

        self.assertEqual(0, seen)

    async def test_iterable(self):
        iterable_mock = asynctest.MagicMock()
        iterable_mock.__aiter__.return_value = range(5)

        seen = 0
        async for _ in iterable_mock:
            seen += 1

        self.assertEqual(5, seen)


class TestCacheWithMagicMethods(asynctest.TestCase):
    """
    Test of ``cache_users_with_cursor()`` where the transaction and the
    cursor are ``MagicMock`` objects.
    """

    async def test_one_user_added_to_cache(self):
        admin = StubClient.User(1, "a.dmin")

        AsyncClientMock = asynctest.create_autospec(AsyncClient)

        # Entering the transaction calls AsyncClientMock
        transaction_mock = asynctest.MagicMock()
        transaction_mock.__aenter__.side_effect = AsyncClientMock

        # The cursor yields the single user
        cursor_mock = asynctest.MagicMock()
        cursor_mock.__aiter__.return_value = [admin]

        mocked_client = AsyncClientMock()
        mocked_client.new_transaction.return_value = transaction_mock
        mocked_client.get_users_cursor.return_value = cursor_mock

        users_by_id = {}

        # First run: the cache is empty, the user must be stored
        first_count = await cache_users_with_cursor(mocked_client,
                                                    users_by_id)

        self.assertEqual(first_count, 1)
        self.assertEqual(users_by_id[1], admin)

        # Second run: the user is already cached, nothing is added
        second_count = await cache_users_with_cursor(mocked_client,
                                                     users_by_id)

        self.assertEqual(second_count, 0)
        self.assertEqual(users_by_id[1], admin)


class TestCachingIsLogged(asynctest.TestCase):
    """
    Tests checking that ``cache_users_async()`` calls ``logging.debug``,
    with ``asynctest.patch`` used as a context manager and as a decorator.
    """

    async def test_with_context_manager(self):
        users_by_id = {}
        mocked_client = asynctest.Mock(AsyncClient())

        # logging.debug is replaced only while the coroutine runs
        with asynctest.patch("logging.debug") as patched_debug:
            await cache_users_async(mocked_client, users_by_id)

        patched_debug.assert_called()

    @asynctest.patch("logging.error")
    @asynctest.patch("logging.debug")
    async def test_with_decorator(self, debug_mock, error_mock):
        # The mock of the innermost decorator (logging.debug) comes first
        users_by_id = {}
        mocked_client = asynctest.Mock(AsyncClient())

        await cache_users_async(mocked_client, users_by_id)

        debug_mock.assert_called()
        error_mock.assert_not_called()