{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Common problems\n", "In this section we will see common problems when mocking our code." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## `patch` not patching the correct object's attribute\n", "It is very common to spend a lot of time figuring out why `patch` does not working. \n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting database.py\n" ] } ], "source": [ "%%writefile database.py\n", "class DBConnection:\n", " def __init__(self, dsn):\n", " print('Connected to real database')\n", " self.dsn = dsn\n", "\n", " def cursor(self):\n", " return Cursor()\n", "\n", " def commit(self):\n", " print('Saved changes')\n", "\n", "class Cursor:\n", " def execute(self, query):\n", " print(\"Executed query={}\".format(query))\n", "\n", "def connect(dsn):\n", " return DBConnection(dsn)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting script.py\n" ] } ], "source": [ "%%writefile script.py\n", "from database import connect\n", "\n", "def clean_db():\n", " conn = connect(dsn=\"user='123', password='xxx', host='hotels.prod.aws.com'\")\n", " cursor = conn.cursor()\n", " cursor.execute('TRUNCATE clickouts')\n", " cursor.execute('TRUNCATE images')\n", " conn.commit()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Connected to real database\n", "Executed query=TRUNCATE clickouts\n", "Executed query=TRUNCATE images\n", "Saved changes\n" ] }, { "ename": "AssertionError", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in 
\u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 10\u001b[0m ]\n\u001b[1;32m 11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mtest_clean_db\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m\u001b[0m in \u001b[0;36mtest_clean_db\u001b[0;34m()\u001b[0m\n\u001b[1;32m 7\u001b[0m assert db_mock().cursor().method_calls == [\n\u001b[1;32m 8\u001b[0m \u001b[0mcall\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'TRUNCATE clickouts'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mcall\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'TRUNCATE images'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m ]\n\u001b[1;32m 11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mAssertionError\u001b[0m: " ] } ], "source": [ "from unittest.mock import Mock, call, patch\n", "from script import clean_db\n", "\n", "def test_clean_db():\n", " with patch('database.connect') as db_mock:\n", " clean_db()\n", " assert db_mock().cursor().method_calls == [\n", " call.execute('TRUNCATE clickouts'),\n", " call.execute('TRUNCATE images')\n", " ]\n", "\n", "test_clean_db()" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "... But we patched `database.connect`! 
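Why did it connect to the real database and execute queries?\n", "\n", "**Explanation:** `patch('database.connect')` does not patch `script.connect`. The line `from database import connect` copied the reference into `script`, so `script.connect` is a second name for the same function object, and `patch` only rebinds `database.connect`.\n", "\n", "**Before patching**\n", "```python\n", "database.connect = <function connect>\n", "database.DBConnection = <class 'database.DBConnection'>\n", "\n", "script.connect = <function connect>  # same object as database.connect\n", "script.clean_db = <function clean_db>\n", "```\n", "\n", "**After patching**\n", "```python\n", "database.connect = <MagicMock name='connect'>  # rebound by patch\n", "database.DBConnection = <class 'database.DBConnection'>\n", "\n", "script.connect = <function connect>  # still the original function!\n", "script.clean_db = <function clean_db>\n", "```\n", "\n", "What `patch('database.connect')` does is patch the attribute `connect` of the `database` module, roughly like this:\n", "\n", "```python\n", "@contextmanager\n", "def patch_database_connect():\n", "    import database\n", "    original_function = database.connect\n", "    database.connect = db_mock = MagicMock()\n", "    try:\n", "        yield db_mock\n", "    finally:\n", "        database.connect = original_function\n", "```\n", "\n", "The attribute `connect` of the `script` module still holds the original reference to `<function connect>`, so `clean_db` keeps calling the real function.\n", "\n", "Be careful with references you copy into your module with `from module import name`: patching the source module does not update them.\n", "\n", "Possible fixes:\n", "1. In `script.py`, replace `from database import connect` with `import database` and call `database.connect`, so the attribute is looked up at call time.\n", "2. Use `patch('script.connect')`, i.e. patch the name where it is looked up, not where it is defined.\n", "3. 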
Use `patch('database.DBConnection')`" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from unittest.mock import Mock, patch, call\n", "from script import clean_db\n", "\n", "def test_clean_db():\n", " with patch('database.DBConnection') as db_mock:\n", " clean_db()\n", " assert db_mock().cursor().method_calls == [\n", " call.execute('TRUNCATE clickouts'),\n", " call.execute('TRUNCATE images')\n", " ]\n", "\n", "test_clean_db()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Use `self` in `side_effect`\n", "A common headache appears when trying to patch a class method with a custom function which receives `self` as a parameter, like so:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from unittest.mock import Mock, MagicMock, patch\n", "class Table:\n", " def __init__(self, name):\n", " self.table_name = name\n", " \n", " def get_rows(self):\n", " print(\"Retrieve rows from database\")\n", " return [1, 2, 3]\n", "\n", "def get_all_data():\n", " users = Table('users')\n", " jobs = Table('jobs')\n", " return {'users': users.get_rows(),\n", " 'jobs': jobs.get_rows()}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One way to test the function `get_all_data` would be to patch `Table.get_rows` function to make it return pre-defined rows based on the value of `self.table_name`, like so:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [ "row_data = {'users': ['user_row_1',\n", " 'user_row_2'],\n", " 'jobs': ['job_row_1',\n", " 'job_row_2',\n", " 'job_row_3']}\n", "\n", "# with patch.object(Table, 'get_rows', side_effect=WHAT DO WE INSERT HERE?):" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The problem comes when we want to define the `side_effect`. 
If you try this, it wouldn't work:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "ename": "TypeError", "evalue": "() missing 1 required positional argument: 'self'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mpatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobject\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mTable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'get_rows'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mside_effect\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mrow_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtable_name\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mget_all_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mrow_data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m\u001b[0m in \u001b[0;36mget_all_data\u001b[0;34m()\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0musers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'users'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0mjobs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'jobs'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m return {'users': users.get_rows(),\n\u001b[0m\u001b[1;32m 14\u001b[0m 'jobs': jobs.get_rows()}\n", "\u001b[0;32m/usr/local/lib/python3.6/unittest/mock.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(_mock_self, 
*args, **kwargs)\u001b[0m\n\u001b[1;32m 937\u001b[0m \u001b[0;31m# in the signature\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 938\u001b[0m \u001b[0m_mock_self\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_mock_check_sig\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 939\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_mock_self\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_mock_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 940\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 941\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.6/unittest/mock.py\u001b[0m in \u001b[0;36m_mock_call\u001b[0;34m(_mock_self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1003\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1004\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1005\u001b[0;31m \u001b[0mret_val\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0meffect\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1006\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1007\u001b[0m if (self._mock_wraps is not None and\n", "\u001b[0;31mTypeError\u001b[0m: () missing 1 required positional argument: 'self'" ] } ], "source": [ "with patch.object(Table, 'get_rows', side_effect=lambda self: row_data[self.table_name]):\n", " assert get_all_data() == row_data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The parameter `self` is not passed to our side_effect, and we 
want it.\n", "\n", "If we check what `get_rows` is, we will see:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", ">\n", "\n", "\n", "\n" ] } ], "source": [ "print(Table.get_rows)\n", "print(Table('users').get_rows)\n", "with patch.object(Table, 'get_rows', side_effect=lambda self: row_data[self.table_name]) as mock_get_rows:\n", " print(mock_get_rows)\n", " print(Table.get_rows)\n", " print(Table('users').get_rows)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The instances `get_rows` method are bounded to the instance.\n", "The difference between a `function` and a `bound method` is that the `self` (instance object) parameter is automatically added to the arguments being called in the bounded method when calling it." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Retrieve rows from database\n", "Retrieve rows from database\n" ] }, { "data": { "text/plain": [ "[1, 2, 3]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "users = Table('users')\n", "Table.get_rows(users) # Calling it like this requires you to pass `self` (instance) attribute.\n", "users.get_rows() # `self` is automatically passed, because it's a bounded method" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What mocking library does is:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['user_row_1', 'user_row_2']\n", "\n" ] }, { "ename": "TypeError", "evalue": "() missing 1 required positional argument: 'self'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in 
\u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mTable\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_rows\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0musers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# This works\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0musers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_rows\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0musers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_rows\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# This doesn't\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0mTable\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_rows\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moriginal_get_rows\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.6/unittest/mock.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(_mock_self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 937\u001b[0m \u001b[0;31m# in the signature\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 938\u001b[0m \u001b[0m_mock_self\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_mock_check_sig\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 939\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_mock_self\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_mock_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 940\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 941\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.6/unittest/mock.py\u001b[0m in \u001b[0;36m_mock_call\u001b[0;34m(_mock_self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1003\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1004\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1005\u001b[0;31m \u001b[0mret_val\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0meffect\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1006\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1007\u001b[0m if (self._mock_wraps is not None and\n", "\u001b[0;31mTypeError\u001b[0m: () missing 1 required positional argument: 'self'" ] } ], "source": [ "original_get_rows = Table.get_rows\n", "try:\n", " mock_get_rows = MagicMock()\n", " mock_get_rows.side_effect = lambda self: row_data[self.table_name] # function, not bounded!\n", " Table.get_rows = mock_get_rows\n", "\n", " users = Table('users')\n", " print(Table.get_rows(users)) # This works\n", " print(users.get_rows)\n", " print(users.get_rows()) # This doesn't\n", "finally:\n", " Table.get_rows = original_get_rows" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Solutions**\n", "\n", "**Option 1** - Use `patch.object` to temporarily assign a new `get_rows` function without a Mock object:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " at 0x7fe29554ed08>\n", " at 0x7fe29554ed08>\n", " of <__main__.Table object at 0x7fe294cf72e8>>\n" ] } ], "source": [ "with patch.object(Table, 'get_rows', new=lambda 
self: row_data[self.table_name]) as mock_get_rows:\n", "    assert get_all_data() == row_data\n", "    print(mock_get_rows)\n", "    print(Table.get_rows)\n", "    print(Table('aaa').get_rows)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Pros**: Short, and works  \n", "**Cons**: You lose the call history (`mock_calls`, `assert_called_with`, ...) that a `Mock` would record for `get_rows`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Option 2**: Create one mocked `Table` instance for each expected table name. This way you can pre-define the `return_value` of each instance's `get_rows` and do not need `self`." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "KeyError('table that is not mocked',)\n" ] } ], "source": [ "from unittest.mock import patch, MagicMock, Mock, call\n", "\n", "def mocked_table_instance(table_name, rows):\n", "    table_inst = MagicMock(table_name=table_name)\n", "    table_inst.get_rows.return_value = rows\n", "    return table_inst\n", "\n", "mocked_tables = {'users': mocked_table_instance('users', row_data['users']),\n", "                 'jobs': mocked_table_instance('jobs', row_data['jobs'])}\n", "with patch('__main__.Table', side_effect=lambda table_name: mocked_tables[table_name]) as table_class_mock:\n", "    assert get_all_data() == row_data\n", "    assert table_class_mock.mock_calls == [call('users'),\n", "                                           call('jobs')]\n", "    assert mocked_tables['users'].method_calls == [call.get_rows()]\n", "    assert mocked_tables['jobs'].method_calls == [call.get_rows()]\n", "    \n", "    try:\n", "        Table('table that is not mocked')\n", "    except KeyError as e:\n", "        print(repr(e))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Pros**: Has call history. The `mocked_table_instance` helper can be reused in multiple tests, giving a centralised way of mocking `Table`.  \n", "**Cons**: Longer patching."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Option 3**: Create a class which simulates `Table`." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "class TableMock:\n", "    def __init__(self, table_name):\n", "        self.table_name = table_name\n", "    def mock_set_rows(self, rows):\n", "        self.rows = rows\n", "    def get_rows(self):\n", "        return self.rows\n", "\n", "mocked_tables = {}\n", "for name, rows in row_data.items():\n", "    mocked_tables[name] = TableMock(name)\n", "    mocked_tables[name].mock_set_rows(rows)\n", "with patch('__main__.Table', side_effect=lambda table_name: mocked_tables[table_name]) as table_class_mock:\n", "    assert Table('users').get_rows() == row_data['users']\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Pros**: Easiest to extend. Depending on how you write it, any method you do not define on the fake class is unavailable by default (calling it raises `AttributeError`), which protects you if you forget to patch a method that touches real files/databases.  \n", "**Cons**: No call history." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Unit test a class method while patching all others\n", "Let's say you want to unit test that `Database.copy_from(other_db)` makes the expected calls:" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class Database:\n", "    def copy_from(self, other_db, drop_all=False):\n", "        if drop_all:\n", "            self.delete_all()\n", "            self.create()\n", "        self.add_users(other_db.get_users())\n", "        self.add_jobs(other_db.get_jobs())\n", "        self.add_categories(other_db.get_categories())\n", "        self.commit()\n", "    \n", "    # Ugly way of defining all other functions\n", "    def noop():\n", "        pass\n", "    delete_all = create = add_users = add_jobs = add_categories = commit = noop" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The usual way of doing it would be:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "collapsed": true }, 
"outputs": [], "source": [ "from unittest.mock import patch, DEFAULT\n", "with patch.multiple(Database, delete_all=DEFAULT, add_users=DEFAULT, add_jobs=DEFAULT, add_categories=DEFAULT, commit=DEFAULT) as mock_db:\n", " other_db_mock = Mock()\n", " db = Database()\n", " db.copy_from(other_db_mock)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are a few problems here:\n", "1. `patch` line is very long. Same would happen if we use multiple `patch.object(Database, function=blabla)` (even longer). Gets worse when having to define `return_value` and `side_effect`\n", "2. No order of methods calls. `method_calls` and `mock_calls` not available, because `Database` class is not mocked.\n", "3. Developer may forgot to patch a `Database` function that should never be executed in tests after refactoring/adding more `add_xxxx` functions.\n", "\n", "A new way of unittesting a single method from a class while automatically patching all others would be to call `Database.copy_from` (unbounded method!) 
with a `Mock` object:" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "collapsed": true }, "outputs": [], "source": [ "db_mock = Mock()\n", "other_db_mock = Mock()\n", "Database.copy_from(db_mock, other_db_mock)\n", "assert db_mock.method_calls == [\n", " call.add_users(other_db_mock.get_users()),\n", " call.add_jobs(other_db_mock.get_jobs()),\n", " call.add_categories(other_db_mock.get_categories()),\n", " call.commit()\n", "]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In case you want to define custom return_values or side_effects to their methods, it is pretty easy and clean:" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "collapsed": true }, "outputs": [], "source": [ "db_mock = Mock()\n", "db_mock.delete_all.side_effect = Exception(\"UNEXPECTED CALL!\")\n", "db_mock.create.side_effect = Exception(\"UNEXPECTED CALL!\")\n", "db_mock.commit.return_value = True\n", "\n", "other_db_mock = Mock()\n", "Database.copy_from(db_mock, other_db_mock)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In case you want `db_mock` to have all attributes that are created/initialized in `__init__`:" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "collapsed": true }, "outputs": [], "source": [ "db_mock = Mock()\n", "Database.__init__(db_mock)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Unit test a classmethod and staticmethod\n", "Following previous approach, it is a bit tricky to call `Database.method` if the `method` is a `staticmethod` or `classmethod`:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### classmethod" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class Algorithms:\n", " @classmethod\n", " def cfib(cls, x):\n", " print(\"> called {}.cfib({})\".format(cls, x))\n", " if x < 2:\n", " return x\n", " return cls.cfib(x-1) + cls.cfib(x-2)" ] }, { "cell_type": "markdown", "metadata": 
{}, "source": [ "We will see that `Algorithms.cfib` is bounded to the class:" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ">\n" ] } ], "source": [ "m = Mock()\n", "print(Algorithms.cfib)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since the function is already bounded (to the class), we can't pass our own `cls` object:" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "scrolled": true }, "outputs": [ { "ename": "TypeError", "evalue": "cfib() takes 2 positional arguments but 3 were given", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMock\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mAlgorithms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcfib\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;31mTypeError\u001b[0m: cfib() takes 2 positional arguments but 3 were given" ] } ], "source": [ "m = Mock()\n", "Algorithms.cfib(m, 5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The solution is to **unbound** the function, which can be done by accessing the bounded method's attribute `__func__`:" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ">\n", "\n" ] } ], "source": [ "print(Algorithms.cfib)\n", "print(Algorithms.cfib.__func__)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ "> called .cfib(5)\n", "\n", "[call.cfib(4), call.cfib(3)]\n" ] } ], "source": [ "m = MagicMock() # So that m.cfib returns a MagicMock, which you can sum with another MagicMock\n", "print(Algorithms.cfib.__func__(m, 5))\n", "print(m.method_calls)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another issue comes when you try to make `m.cfib` work like `Algorithms.cfib` does:" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "> called .cfib(3)\n", "> called .cfib(2)\n", "> called .cfib(1)\n", "> called .cfib(0)\n", "> called .cfib(1)\n", "2\n", "[]\n" ] } ], "source": [ "m = MagicMock()\n", "m.cfib = Algorithms.cfib\n", "print(Algorithms.cfib.__func__(m, 3))\n", "print(m.method_calls)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Does not work** because `cls` is not our `mock` object. The object `m.cfib` is a function **bounded to `Algorithms`**, not our `mock`!\n", "\n", "It is possible to change `cfib` and make it a function bounded to `m`:" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "> called .cfib(3)\n", "> called .cfib(2)\n", "> called .cfib(1)\n", "> called .cfib(0)\n", "> called .cfib(1)\n", "2\n", "[]\n" ] } ], "source": [ "import types\n", "m = MagicMock()\n", "m.cfib = types.MethodType(Algorithms.cfib.__func__, m)\n", "print(m.cfib(3))\n", "print(m.method_calls)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Perfect, but where is our call history? 
**If you want call history, you must use side_effect or return_value**.\n", "\n", "**Solution:**" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "> called .cfib(4)\n", "> called .cfib(3)\n", "> called .cfib(2)\n", "> called .cfib(1)\n", "> called .cfib(0)\n", "> called .cfib(1)\n", "> called .cfib(2)\n", "> called .cfib(1)\n", "> called .cfib(0)\n", "3\n", "[call.cfib(4),\n", " call.cfib(3),\n", " call.cfib(2),\n", " call.cfib(1),\n", " call.cfib(0),\n", " call.cfib(1),\n", " call.cfib(2),\n", " call.cfib(1),\n", " call.cfib(0)]\n" ] } ], "source": [ "m = MagicMock()\n", "m.cfib.side_effect = types.MethodType(Algorithms.cfib.__func__, m)\n", "print(m.cfib(4))\n", "print(m.method_calls)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### staticmethod\n", "Use `wraps` to track history of calls on `Algorithms.fib` (with recursion too!):" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class Algorithms:\n", " @staticmethod\n", " def fib(x):\n", " if x < 2:\n", " return x\n", " return Algorithms.fib(x-1) + Algorithms.fib(x-2)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[call(4),\n", " call(3),\n", " call(2),\n", " call(1),\n", " call(0),\n", " call(1),\n", " call(2),\n", " call(1),\n", " call(0)]\n", "[call(4),\n", " call(3),\n", " call(2),\n", " call(1),\n", " call(0),\n", " call(1),\n", " call(2),\n", " call(1),\n", " call(0)]\n" ] } ], "source": [ "with patch.object(Algorithms, 'fib', wraps=Algorithms.fib) as fib_mock:\n", " Algorithms.fib(4)\n", " print(fib_mock.mock_calls)\n", "\n", "with patch.object(Algorithms, 'fib', side_effect=Algorithms.fib) as fib_mock:\n", " Algorithms.fib(4)\n", " print(fib_mock.mock_calls)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In case you 
don't want recursion and just want to check that `fib(x)` calls `fib(x-1)` and `fib(x-2)`:" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[call(3), call(2)]\n" ] } ], "source": [ "orig_fib = Algorithms.fib\n", "with patch.object(Algorithms, 'fib', return_value=0) as fib_mock:\n", "    orig_fib(4)\n", "    print(fib_mock.mock_calls)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "## Attaching Mocks as attributes\n", "When you attach a mock as an attribute of another mock, it becomes a \"child\" of that mock. Calls to the child are recorded in the `method_calls` and `mock_calls` attributes of the parent.\n", "However, if the child `Mock` was created with a `name`, assigning it as an attribute does not attach it, so the parent does not record its calls:" ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[call.child1('abc')]\n", "[call.child1('abc')]\n" ] } ], "source": [ "m = Mock(name='parent')\n", "child1 = Mock()\n", "child2 = Mock(name='child_two')\n", "m.child1 = child1\n", "m.child2 = child2\n", "\n", "child1('abc')\n", "child2(1, 2, 3)\n", "print(m.method_calls)\n", "print(m.mock_calls)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Mocks created by `patch()` are automatically given names, so assigning them as attributes does not attach them either:" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "[]\n", "[call.now()]\n" ] } ], "source": [ "import datetime\n", "m = Mock()\n", "with patch('datetime.datetime') as child1:\n", "    m.datetime = child1\n", "    print(datetime.datetime.now())\n", "print(m.method_calls)\n", "print(m.datetime.method_calls)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To attach mocks that have names to a parent, use the `Mock` method `attach_mock`:" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, 
"outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "[call.datetime.now()]\n", "[call.now()]\n" ] } ], "source": [ "import datetime\n", "m = Mock()\n", "with patch('datetime.datetime') as child1:\n", " m.attach_mock(child1, 'datetime')\n", " print(datetime.datetime.now())\n", "print(m.method_calls)\n", "print(m.datetime.method_calls)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "## Mocking/patching `async` methods\n", "To test projects which work with `async` and `coroutines`, I recommend using `asynctest` library: https://pypi.python.org/pypi/asynctest/0.5.0\n", "\n", "This library has a new mock object `CoroutineMock` which lets you define `return_value` and `side_effect` of your functions without having to worry about them being `async`, `Future` objects or anything." ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ ".." ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n", "----------------------------------------------------------------------\n", "Ran 2 tests in 0.016s\n", "\n", "OK\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import asynctest\n", "from asynctest import patch, Mock\n", "import asyncio\n", "\n", "loop = asyncio.new_event_loop()\n", "asyncio.set_event_loop(loop)\n", "\n", "class AsyncThing:\n", " async def method(self):\n", " asyncio.sleep(50)\n", " return 5\n", " \n", " def normal_method(self):\n", " return 123\n", "\n", "class TestSomething(asynctest.TestCase):\n", " use_default_loop = True\n", " async def test_something(self):\n", " a = AsyncThing()\n", " with patch('asyncio.sleep'):\n", " x = await a.method()\n", " assert x == 5\n", " \n", " async def test_class_mock(self):\n", " # Using `spec` makes it create CoroutineMock or 
MagicMock, depending on if the method is async or not.\n", " # Make sure to import patch from asynctest!\n", " with patch('__main__.AsyncThing', spec=AsyncThing) as asyncthing_mock:\n", " print(asyncthing_mock.method)\n", " print(asyncthing_mock.normal_method)\n", "\n", "ts = TestSomething()\n", "suite = asynctest.TestLoader().loadTestsFromModule(ts)\n", "asynctest.TextTestRunner().run(suite)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One missing feature from `asynctest` is the behaviour with `async with` (`__aenter__` and `__aexit__` methods). By default, using `asynctest` would not work.\n", "\n", "Current \"best known\" solution is to create a new `Mock` class:" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "." ] }, { "name": "stdout", "output_type": "stream", "text": [ "first \n", "second 5\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n", "----------------------------------------------------------------------\n", "Ran 1 test in 0.007s\n", "\n", "OK\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class AsyncContextManagerMock(MagicMock):\n", " def __init__(self, *args, **kwargs):\n", " super().__init__(*args, **kwargs)\n", "\n", " setattr(self, 'aenter_return', kwargs.get('aenter_return', MagicMock()))\n", " setattr(self, 'aexit_return', kwargs.get('aexit_return', None))\n", "\n", " async def __aenter__(self):\n", " return self.aenter_return\n", "\n", " async def __aexit__(self, exc_type, exc_value, traceback):\n", " return self.aexit_return\n", "\n", "class TestSomething(asynctest.TestCase):\n", " use_default_loop = True\n", " async def test_async_with(self):\n", " async with AsyncContextManagerMock() as mock:\n", " print('first', mock)\n", " async with AsyncContextManagerMock(aenter_return=5) as value:\n", " print('second', value)\n", "\n", "ts = 
TestSomething()\n", "suite = asynctest.TestLoader().loadTestsFromModule(ts)\n", "asynctest.TextTestRunner().run(suite)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.1" } }, "nbformat": 4, "nbformat_minor": 2 }