# tutorial/mocking_io.py

# coding: utf-8
import asyncio
import socket

import asynctest


class TestMockASocket(asynctest.TestCase):
    """Drive asyncio streams over an ``asynctest.SocketMock``."""

    async def test_read_and_write_from_socket(self):
        """Send a request, then read scripted response chunks until EOF.

        The mocked socket is only marked readable once the request has been
        sent. ``recv`` then replays ``recv_data`` chunk by chunk and, once the
        chunks are exhausted, returns ``b""`` so ``reader.read()`` (which
        reads until end-of-file) can complete.
        """
        socket_mock = asynctest.SocketMock()
        # open_connection(sock=...) rejects sockets that are not SOCK_STREAM.
        socket_mock.type = socket.SOCK_STREAM

        # Chunks "received" from the peer, in order.
        recv_data = iter((
            b"some data read",
            b"some other",
            b" ...and the last",
        ))

        # Bytes of the current chunk that recv() has not returned yet.
        recv_buffer = bytearray()

        def recv_side_effect(max_bytes):
            """Emulate ``socket.recv``: return at most *max_bytes* bytes."""
            if not recv_buffer:
                chunk = next(recv_data, None)
                if chunk is not None:
                    recv_buffer.extend(chunk)
                    asynctest.set_read_ready(socket_mock, self.loop)
                # Otherwise nothing is left: the empty result below signals
                # EOF to the stream reader.

            # Hand back real ``bytes``, as socket.recv does, and consume them
            # from the pending buffer in place.
            data = bytes(recv_buffer[:max_bytes])
            del recv_buffer[:max_bytes]

            if recv_buffer:
                # Part of the chunk is still pending: stay readable.
                asynctest.set_read_ready(socket_mock, self.loop)

            return data

        def send_side_effect(data):
            """Accept the whole payload and make the "response" readable."""
            asynctest.set_read_ready(socket_mock, self.loop)
            return len(data)

        socket_mock.recv.side_effect = recv_side_effect
        socket_mock.send.side_effect = send_side_effect

        reader, writer = await asyncio.open_connection(sock=socket_mock)

        writer.write(b"a request?")
        self.assertEqual(b"some", await reader.read(4))
        self.assertEqual(b" data read", await reader.read(10))
        self.assertEqual(b"some other ...and the last", await reader.read())

        # The request must actually have been handed to the socket.
        socket_mock.send.assert_called_once_with(b"a request?")