# Common problems
In this section we will see common problems when mocking our code.

## `patch` not patching the correct object's attribute
It is very common to spend a lot of time figuring out why `patch` is not working.


In [1]:
%%writefile database.py
class DBConnection:
    """Stand-in for a real database connection used throughout this tutorial.

    Creating one prints a message, so the cell output makes it obvious when
    the *real* class was used instead of a mock.
    """

    def __init__(self, dsn):
        print('Connected to real database')
        self.dsn = dsn  # connection string, stored exactly as given

    def cursor(self):
        """Return a fresh ``Cursor`` for issuing queries."""
        new_cursor = Cursor()
        return new_cursor

    def commit(self):
        """Pretend to persist pending changes (only prints a message)."""
        print('Saved changes')

class Cursor:
    """Fake cursor that echoes every query it is asked to execute."""

    def execute(self, query):
        """Print *query* instead of sending it to a database."""
        message = "Executed query={}".format(query)
        print(message)

def connect(dsn):
    """Open a (fake) connection described by *dsn* and return it."""
    connection = DBConnection(dsn)
    return connection

Overwriting database.py


In [2]:
%%writefile script.py
from database import connect

def clean_db():
    """Truncate the ``clickouts`` and ``images`` tables, then commit.

    The connection comes from the ``connect`` name imported into this
    module, so that is the reference a test must patch (``script.connect``).
    """
    conn = connect(dsn="user='123', password='xxx', host='hotels.prod.aws.com'")
    cursor = conn.cursor()
    for query in ('TRUNCATE clickouts', 'TRUNCATE images'):
        cursor.execute(query)
    conn.commit()

Overwriting script.py


In [3]:
from unittest.mock import Mock, call, patch
from script import clean_db

def test_clean_db():
    # Deliberately wrong target: script.py did `from database import connect`,
    # so it holds its own reference to the real function and this patch
    # never reaches it.
    with patch('database.connect') as db_mock:
        clean_db()
        expected_calls = [
            call.execute('TRUNCATE clickouts'),
            call.execute('TRUNCATE images')
        ]
        assert db_mock().cursor().method_calls == expected_calls

test_clean_db()

Connected to real database
Executed query=TRUNCATE clickouts
Executed query=TRUNCATE images
Saved changes


AssertionError: 

... But we patched `database.connect`! Why did it connect to the real database and execute queries?

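**Explanation:** `patch('database.connect')` does not patch `script.connect`, because `script.connect` is a separate reference to the original function. That reference was copied into `script` when it ran `from database import connect`.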

**Before patching**
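```python
database.connect = <function connect>
database.DBConnection = <class 'database.DBConnection'>

script.connect = <function connect>   # same object as database.connect
script.clean_db = <function clean_db>
```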

**After patching**
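```python
database.connect = <MagicMock name='connect'>   # replaced by patch
database.DBConnection = <class 'database.DBConnection'>

script.connect = <function connect>   # still the original function!
script.clean_db = <function clean_db>
```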

What `patch('database.connect')` does is patch the attribute `connect` of `database` module:

```python
# Roughly what `with patch('database.connect') as db_mock:` does:
import database
original_function = database.connect
database.connect = db_mock = MagicMock()
try:
    ...  # the body of the `with` block runs here
finally:
    database.connect = original_function
```

The `connect` attribute of the `script` module is a copy of the original reference to `<function connect>`, so replacing `database.connect` afterwards does not change it.

Be careful with the references that `from ... import ...` copies into your module.

Possible fixes:
1. In `script.py`, replace `from database import connect` with `import database` and call `database.connect`.
2. Use `patch('script.connect')`
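3. Use `patch('database.DBConnection')`. This works because `database.connect` looks up `DBConnection` in the `database` module every time it is called, so it picks up the mock.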

In [4]:
from unittest.mock import Mock, patch, call
from script import clean_db

def test_clean_db():
    """Check clean_db's queries with the real DBConnection class patched out.

    Patching ``database.DBConnection`` works even though ``script`` copied
    ``connect``: ``database.connect`` looks ``DBConnection`` up in the
    ``database`` module each time it runs.
    """
    with patch('database.DBConnection') as db_mock:
        clean_db()
        # Use .return_value instead of calling the mocks again: calling
        # db_mock() / .cursor() inside the assertion would add extra entries
        # to the very call history being checked.
        conn_mock = db_mock.return_value
        cursor_mock = conn_mock.cursor.return_value
        db_mock.assert_called_once()
        assert cursor_mock.method_calls == [
            call.execute('TRUNCATE clickouts'),
            call.execute('TRUNCATE images')
        ]
        conn_mock.commit.assert_called_once_with()

test_clean_db()

## Use `self` in `side_effect`
A common headache appears when trying to patch an instance method with a custom function that receives `self` as a parameter, like so:

In [6]:
from unittest.mock import Mock, MagicMock, patch
class Table:
    """Minimal table wrapper whose ``get_rows`` pretends to hit a database."""

    def __init__(self, name):
        self.table_name = name

    def get_rows(self):
        """Print a marker line and return hard-coded rows."""
        print("Retrieve rows from database")
        rows = [1, 2, 3]
        return rows

def get_all_data():
    """Return the rows of the ``users`` and ``jobs`` tables, keyed by table name.

    Both ``Table`` objects are created before either ``get_rows`` call.
    """
    users_table, jobs_table = Table('users'), Table('jobs')
    return {'users': users_table.get_rows(), 'jobs': jobs_table.get_rows()}

One way to test the function `get_all_data` would be to patch `Table.get_rows` function to make it return pre-defined rows based on the value of `self.table_name`, like so:

In [7]:
# Canned rows each mocked table should return, keyed by table name.
row_data = dict(
    users=['user_row_1', 'user_row_2'],
    jobs=['job_row_1', 'job_row_2', 'job_row_3'],
)

# with patch.object(Table, 'get_rows', side_effect=WHAT DO WE INSERT HERE?):

The problem comes when we want to define the `side_effect`. If you try this, it won't work:

In [9]:
# Intentionally fails with TypeError: the mock installed on the class is
# called without the instance, so the lambda never receives `self`.
with patch.object(Table, 'get_rows', side_effect=lambda self: row_data[self.table_name]):
 assert get_all_data() == row_data

TypeError: <lambda>() missing 1 required positional argument: 'self'

The `self` argument is not passed to our `side_effect`, but we need it.

If we check what `get_rows` is, we will see:

In [10]:
# Inspect `get_rows` unpatched (plain function on the class, bound method on
# an instance) and again while it is patched with a MagicMock.
print(Table.get_rows)
print(Table('users').get_rows)
with patch.object(Table, 'get_rows', side_effect=lambda self: row_data[self.table_name]) as mock_get_rows:
 print(mock_get_rows)
 print(Table.get_rows)
 print(Table('users').get_rows)


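<function Table.get_rows at 0x...>
<bound method Table.get_rows of <__main__.Table object at 0x...>>
<MagicMock name='get_rows' id='...'>
<MagicMock name='get_rows' id='...'>
<MagicMock name='get_rows' id='...'>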


An instance's `get_rows` method is *bound* to that instance.
The difference between a plain `function` and a `bound method` is that, when you call a bound method, the instance (`self`) is automatically passed as the first argument.

In [11]:
users = Table('users')
Table.get_rows(users) # Called on the class, you must pass the instance (`self`) explicitly.
users.get_rows() # `self` is passed automatically, because it's a bound method

Retrieve rows from database
Retrieve rows from database


[1, 2, 3]

What the mocking library does is roughly:

In [15]:
# Hand-written equivalent of patch.object(Table, 'get_rows', side_effect=...):
# save the original, install a MagicMock on the class, restore it in `finally`.
original_get_rows = Table.get_rows
try:
 mock_get_rows = MagicMock()
 mock_get_rows.side_effect = lambda self: row_data[self.table_name] # plain function, not bound to any instance!
 Table.get_rows = mock_get_rows

 users = Table('users')
 print(Table.get_rows(users)) # This works: `self` is passed explicitly
 print(users.get_rows)
 print(users.get_rows()) # This doesn't: the mock is called without `self` (TypeError)
finally:
 Table.get_rows = original_get_rows

['user_row_1', 'user_row_2']
<MagicMock id='...'>

TypeError: <lambda>() missing 1 required positional argument: 'self'

**Solutions**

**Option 1** - Use `patch.object` to temporarily assign a new `get_rows` function without a Mock object:

In [18]:
# Option 1: `new=` installs a real function on the class, so instances get a
# bound method and `self` is passed in. The object yielded by patch.object is
# that lambda itself, not a Mock, so no call history is recorded.
with patch.object(Table, 'get_rows', new=lambda self: row_data[self.table_name]) as mock_get_rows:
 assert get_all_data() == row_data
 print(mock_get_rows)
 print(Table.get_rows)
 print(Table('aaa').get_rows)

<function <lambda> at 0x7fe29554ed08>
<function <lambda> at 0x7fe29554ed08>
<bound method <lambda> of <__main__.Table object at 0x7fe294cf72e8>>


**Pros**: Short, and works 
**Cons**: You lose the call history of `get_rows` that you would get by using a `Mock` as `new`.

**Option 2**: Create one mocked `Table` instance for each expected table. This way you can pre-define the `return_value` of each instance and you don't need `self`.

In [21]:
from unittest.mock import patch, MagicMock, Mock, call

def mocked_table_instance(table_name, rows):
    """Build a MagicMock standing in for ``Table(table_name)``.

    The mock exposes ``table_name``, and calling its ``get_rows()`` returns
    *rows*. Dotted keyword names configure child mocks in one call.
    """
    return MagicMock(table_name=table_name, **{'get_rows.return_value': rows})

# One pre-configured mock per table the code under test is expected to open.
mocked_tables = {
    name: mocked_table_instance(name, row_data[name])
    for name in ('users', 'jobs')
}
with patch('__main__.Table', side_effect=lambda table_name: mocked_tables[table_name]) as table_class_mock:
    assert get_all_data() == row_data
    assert table_class_mock.mock_calls == [call('users'), call('jobs')]
    for name in ('users', 'jobs'):
        assert mocked_tables[name].method_calls == [call.get_rows()]

    # A table that was not pre-configured fails loudly with KeyError.
    try:
        Table('table that is not mocked')
    except KeyError as e:
        print(repr(e))

KeyError('table that is not mocked',)


**Pros**: Has call history. Common `mocked_table_instance` function can be used in multiple tests, centralised way of mocking `Table`. 
**Cons**: Longer patching.

**Option 3**: Create a class which simulates `Table`. 

In [24]:
class TableMock:
    """Hand-written fake of ``Table``.

    Only the methods defined here exist. Code touching any other ``Table``
    method fails with ``AttributeError`` instead of reaching a real database.
    """

    def __init__(self, table_name):
        self.table_name = table_name

    def mock_set_rows(self, rows):
        """Configure the rows that ``get_rows`` will return."""
        self.rows = rows

    def get_rows(self):
        """Return the rows previously given to ``mock_set_rows``."""
        return self.rows

# Build one configured fake per entry in row_data.
mocked_tables = {}
for name, rows in row_data.items():
    fake_table = TableMock(name)
    fake_table.mock_set_rows(rows)
    mocked_tables[name] = fake_table

with patch('__main__.Table', side_effect=lambda table_name: mocked_tables[table_name]) as table_class_mock:
    assert Table('users').get_rows() == row_data['users']


**Pros**: Easiest to extend. Any `Table` method you did not implement on the fake simply does not exist, so calling it raises `AttributeError` (good if you forgot to patch a method which touches real files/databases).
**Cons**: No call history

## Unit test a class function while patching all others
Let's say you want to unit test that the method `Database.copy_from(other_db)` makes the expected calls:

In [25]:
class Database:
    """Toy database used to demonstrate unit testing a single method."""

    def copy_from(self, other_db, drop_all=False):
        """Copy users, jobs and categories from *other_db*, then commit.

        If *drop_all* is true, ``delete_all()`` and ``create()`` are called
        first. The results of ``other_db.get_users()``, ``get_jobs()`` and
        ``get_categories()`` are passed to ``add_users``, ``add_jobs`` and
        ``add_categories`` in that order, followed by ``commit()``.
        """
        if drop_all:
            self.delete_all()
            self.create()
        self.add_users(other_db.get_users())
        self.add_jobs(other_db.get_jobs())
        self.add_categories(other_db.get_categories())
        self.commit()

    # Placeholder implementation shared by every other method. It must accept
    # (and ignore) `self` plus any arguments: a zero-parameter noop raises
    # TypeError as soon as it is called through an instance, e.g.
    # `self.add_users(rows)`.
    def noop(self, *args, **kwargs):
        return None

    delete_all = create = add_users = add_jobs = add_categories = commit = noop

The usual way of doing it would be:

In [26]:
from unittest.mock import patch, DEFAULT
with patch.multiple(
    Database,
    delete_all=DEFAULT,
    add_users=DEFAULT,
    add_jobs=DEFAULT,
    add_categories=DEFAULT,
    commit=DEFAULT,
) as mock_db:
    # `create` is not patched; copy_from only calls it when drop_all=True.
    other_db_mock = Mock()
    db = Database()
    db.copy_from(other_db_mock)

There are a few problems here:
1. The `patch` line is very long. The same happens with multiple `patch.object(Database, 'method')` calls (even longer), and it gets worse when you have to define `return_value` and `side_effect`.
2. No ordering of method calls: `method_calls` and `mock_calls` are not available, because the `Database` class itself is not mocked.
3. A developer may forget to patch a `Database` method that should never run in tests, e.g. after refactoring or adding more `add_xxxx` methods.

A different way to unit test a single method of a class, while automatically patching all the others, is to call `Database.copy_from` with a `Mock` object as `self`. Accessed on the class, `Database.copy_from` is a plain, unbound function:

In [27]:
# Call the plain function stored on the class with a Mock as `self`: every
# other Database method that copy_from touches is recorded on db_mock
# instead of running.
db_mock = Mock()
other_db_mock = Mock()
Database.copy_from(db_mock, other_db_mock)
# Compare against .return_value instead of calling other_db_mock's methods
# again, which would add extra calls to other_db_mock's own history.
assert db_mock.method_calls == [
    call.add_users(other_db_mock.get_users.return_value),
    call.add_jobs(other_db_mock.get_jobs.return_value),
    call.add_categories(other_db_mock.get_categories.return_value),
    call.commit()
]
assert other_db_mock.method_calls == [
    call.get_users(),
    call.get_jobs(),
    call.get_categories()
]

In case you want to define custom return_values or side_effects to their methods, it is pretty easy and clean:

In [28]:
# Configure the mock in one go (dotted names reach child mocks). The
# drop_all-only methods raise, so an unexpected call fails the cell loudly.
db_mock = Mock(**{
    'delete_all.side_effect': Exception("UNEXPECTED CALL!"),
    'create.side_effect': Exception("UNEXPECTED CALL!"),
    'commit.return_value': True,
})

other_db_mock = Mock()
Database.copy_from(db_mock, other_db_mock)

In case you want `db_mock` to have all attributes that are created/initialized in `__init__`:

In [29]:
# Run the class's __init__ against the mock so the mock gets whatever
# attributes __init__ assigns. NOTE: the Database class above defines no
# __init__, so here this resolves to object.__init__ and adds nothing; the
# technique matters for classes that define their own __init__.
db_mock = Mock()
Database.__init__(db_mock)

## Unit test a classmethod and staticmethod
Following the previous approach, it is a bit tricky to call `Database.method` if the `method` is a `staticmethod` or `classmethod`:

### classmethod

In [30]:
class Algorithms:
    """Recursive Fibonacci implemented as a classmethod."""

    @classmethod
    def cfib(cls, x):
        """Return the x-th Fibonacci number, printing each call.

        Recursion goes through ``cls.cfib`` (not ``Algorithms.cfib``), so a
        caller that supplies its own ``cls`` receives the recursive calls.
        """
        print("> called {}.cfib({})".format(cls, x))
        return x if x < 2 else cls.cfib(x - 1) + cls.cfib(x - 2)

We will see that `Algorithms.cfib` is bound to the class:

In [31]:
m = Mock()  # not used by this cell
print(Algorithms.cfib)  # a bound method: `cls` is already fixed to Algorithms

<bound method Algorithms.cfib of <class '__main__.Algorithms'>>


Since the method is already bound (to the class), we can't pass our own `cls` object:

In [33]:
# Intentionally fails with TypeError: cfib is already bound to Algorithms,
# so `m` is taken as `x` and 5 becomes an extra positional argument.
m = Mock()
Algorithms.cfib(m, 5)

TypeError: cfib() takes 2 positional arguments but 3 were given

The solution is to **unbind** the function, which can be done through the bound method's `__func__` attribute:

In [34]:
# The bound classmethod vs. the underlying plain function, which takes `cls`
# as an explicit first argument.
print(Algorithms.cfib)
print(Algorithms.cfib.__func__)

<bound method Algorithms.cfib of <class '__main__.Algorithms'>>
<function Algorithms.cfib at 0x...>


In [36]:
# Call the unbound function with a MagicMock as `cls`: the recursive
# cls.cfib(...) calls go to the mock and are recorded instead of executed.
m = MagicMock() # So that m.cfib returns a MagicMock, which you can sum with another MagicMock
print(Algorithms.cfib.__func__(m, 5))
print(m.method_calls)

> called <MagicMock id='...'>.cfib(5)
<MagicMock name='mock.cfib().__add__()' id='...'>
[call.cfib(4), call.cfib(3)]


Another issue comes when you try to make `m.cfib` work like `Algorithms.cfib` does:

In [39]:
# Does NOT work as intended: m.cfib is Algorithms.cfib, still bound to
# Algorithms, so every recursive call after the first runs on the real class
# and nothing is recorded on m.
m = MagicMock()
m.cfib = Algorithms.cfib
print(Algorithms.cfib.__func__(m, 3))
print(m.method_calls)

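> called <MagicMock id='...'>.cfib(3)
> called <class '__main__.Algorithms'>.cfib(2)
> called <class '__main__.Algorithms'>.cfib(1)
> called <class '__main__.Algorithms'>.cfib(0)
> called <class '__main__.Algorithms'>.cfib(1)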
2
[]


**Does not work**, because `cls` is not our `mock` object: `m.cfib` is a method **bound to `Algorithms`**, not to our mock!

It is possible to rebind `cfib` and make it a method bound to `m`:

In [40]:
import types
# Bind the plain function to `m`, so `cls` is `m` at every recursion level.
# m.cfib is now an ordinary bound method rather than a Mock, so its calls are
# not recorded in m.method_calls.
m = MagicMock()
m.cfib = types.MethodType(Algorithms.cfib.__func__, m)
print(m.cfib(3))
print(m.method_calls)

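> called <MagicMock id='...'>.cfib(3)
> called <MagicMock id='...'>.cfib(2)
> called <MagicMock id='...'>.cfib(1)
> called <MagicMock id='...'>.cfib(0)
> called <MagicMock id='...'>.cfib(1)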
2
[]


Perfect, but where is our call history? **If you want call history, you must use side_effect or return_value**.

**Solution:**

In [41]:
# Keep m.cfib a Mock (so every call is recorded) and let it run the real
# implementation, bound to `m`, through side_effect.
m = MagicMock()
m.cfib.side_effect = types.MethodType(Algorithms.cfib.__func__, m)
print(m.cfib(4))
print(m.method_calls)

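> called <MagicMock id='...'>.cfib(4)
> called <MagicMock id='...'>.cfib(3)
> called <MagicMock id='...'>.cfib(2)
> called <MagicMock id='...'>.cfib(1)
> called <MagicMock id='...'>.cfib(0)
> called <MagicMock id='...'>.cfib(1)
> called <MagicMock id='...'>.cfib(2)
> called <MagicMock id='...'>.cfib(1)
> called <MagicMock id='...'>.cfib(0)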
3
[call.cfib(4),
 call.cfib(3),
 call.cfib(2),
 call.cfib(1),
 call.cfib(0),
 call.cfib(1),
 call.cfib(2),
 call.cfib(1),
 call.cfib(0)]


### staticmethod
Use `wraps` to track history of calls on `Algorithms.fib` (with recursion too!):

In [43]:
class Algorithms:
    """Recursive Fibonacci implemented as a staticmethod."""

    @staticmethod
    def fib(x):
        """Return the x-th Fibonacci number.

        Recursion goes through the class attribute ``Algorithms.fib``, so
        patching that attribute also intercepts the recursive calls.
        """
        if x >= 2:
            return Algorithms.fib(x - 1) + Algorithms.fib(x - 2)
        return x

In [44]:
# `wraps=` and `side_effect=` both forward to the original fib (the argument
# is evaluated before patching starts). Its recursive Algorithms.fib(...)
# calls hit the patched class attribute, so every level is recorded.
with patch.object(Algorithms, 'fib', wraps=Algorithms.fib) as fib_mock:
 Algorithms.fib(4)
 print(fib_mock.mock_calls)

with patch.object(Algorithms, 'fib', side_effect=Algorithms.fib) as fib_mock:
 Algorithms.fib(4)
 print(fib_mock.mock_calls)

[call(4),
 call(3),
 call(2),
 call(1),
 call(0),
 call(1),
 call(2),
 call(1),
 call(0)]
[call(4),
 call(3),
 call(2),
 call(1),
 call(0),
 call(1),
 call(2),
 call(1),
 call(0)]


In case you don't want recursion and just check that `fib(x)` calls `fib(x-1)` and `fib(x-2)`:

In [45]:
# Keep a reference to the real fib, then patch the class attribute to return
# 0. orig_fib(4) runs one real level, and its two recursive calls hit the mock.
orig_fib = Algorithms.fib
with patch.object(Algorithms, 'fib', return_value=0) as fib_mock:
 orig_fib(4)
 print(fib_mock.mock_calls)

[call(3), call(2)]


## Attaching Mocks as attributes
When you attach a mock as an attribute of another mock, it becomes a "child" of that mock. Calls to the child are recorded in the `method_calls` and `mock_calls` attributes of the parent.
If the child `Mock` was created with a `name`, it is not attached as a child, so the parent does not record its calls:

In [46]:
# child1 has no name, so assigning it to m attaches it as a child and its
# call shows up in m's history. child2 was created with name=..., so it is
# not attached and its call is not recorded on m.
m = Mock(name='parent')
child1 = Mock()
child2 = Mock(name='child_two')
m.child1 = child1
m.child2 = child2

child1('abc')
child2(1, 2, 3)
print(m.method_calls)
print(m.mock_calls)

[call.child1('abc')]
[call.child1('abc')]


Mocks created by `patch()` are automatically given names:

In [47]:
import datetime
# patch() gives its mock a name, so `m.datetime = child1` does not attach it:
# m records nothing, while child1 still records its own calls.
m = Mock()
with patch('datetime.datetime') as child1:
 m.datetime = child1
 print(datetime.datetime.now())
print(m.method_calls)
print(m.datetime.method_calls)


<MagicMock name='datetime.now()' id='...'>
[]
[call.now()]


To attach mocks that have names to a parent, you can use the `Mock` method `attach_mock`:

In [48]:
import datetime
# attach_mock attaches even a named mock, so the call made through
# datetime.datetime also shows up on m as call.datetime.now().
m = Mock()
with patch('datetime.datetime') as child1:
 m.attach_mock(child1, 'datetime')
 print(datetime.datetime.now())
print(m.method_calls)
print(m.datetime.method_calls)


<MagicMock name='mock.datetime.now()' id='...'>
[call.datetime.now()]
[call.now()]


## Mocking/patching `async` methods
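To test projects which work with `async` and coroutines, I recommend the `asynctest` library: https://pypi.org/project/asynctest/ . Since Python 3.8, the standard library also offers `unittest.mock.AsyncMock` and `unittest.IsolatedAsyncioTestCase`, which cover many of the same needs.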

This library has a new mock object `CoroutineMock` which lets you define `return_value` and `side_effect` of your functions without having to worry about them being `async`, `Future` objects or anything.

In [49]:
import asynctest
from asynctest import patch, Mock
import asyncio

# Create a fresh event loop and make it the current one. The test cases below
# set use_default_loop = True, presumably so asynctest runs them on this loop
# (confirm against asynctest's TestCase docs).
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

class AsyncThing:
    """Small class with one coroutine method and one regular method."""

    async def method(self):
        """Wait 50 seconds, then return 5.

        The sleep must be awaited: calling ``asyncio.sleep(50)`` without
        ``await`` only creates a coroutine object that never runs (Python
        warns that it was never awaited), and the method returns at once.
        """
        await asyncio.sleep(50)
        return 5

    def normal_method(self):
        """Return 123 synchronously."""
        return 123

class TestSomething(asynctest.TestCase):
    """asynctest examples: patching a coroutine and spec-ing an async class."""

    use_default_loop = True

    async def test_something(self):
        thing = AsyncThing()
        # Patch asyncio.sleep so method() does not really wait (presumably
        # asynctest substitutes an awaitable CoroutineMock here - confirm).
        with patch('asyncio.sleep'):
            result = await thing.method()
            assert result == 5

    async def test_class_mock(self):
        # Using `spec` makes it create CoroutineMock or MagicMock, depending on if the method is async or not.
        # Make sure to import patch from asynctest!
        with patch('__main__.AsyncThing', spec=AsyncThing) as asyncthing_mock:
            for attribute in (asyncthing_mock.method, asyncthing_mock.normal_method):
                print(attribute)

# Build the suite straight from the TestCase class. loadTestsFromModule()
# expects a module; passing a TestSomething instance only worked because the
# loader scans dir() and happened to find the class through __class__.
suite = asynctest.TestLoader().loadTestsFromTestCase(TestSomething)
asynctest.TextTestRunner().run(suite)


..






----------------------------------------------------------------------
Ran 2 tests in 0.016s

OK




One missing feature in `asynctest` is support for `async with` (the `__aenter__` and `__aexit__` methods): by default, `asynctest` mocks do not work there. (On Python 3.8+, `unittest.mock.MagicMock` supports `__aenter__` and `__aexit__` out of the box.)

Current "best known" solution is to create a new `Mock` class:

In [50]:
class AsyncContextManagerMock(MagicMock):
    """MagicMock that can be used in ``async with``.

    Keyword arguments:
        aenter_return: value bound by ``async with ... as value``
            (defaults to a new MagicMock).
        aexit_return: value returned from ``__aexit__``. It defaults to None,
            which is falsy, so exceptions raised in the block propagate.

    NOTE(review): Python 3.8+ gives MagicMock its own ``__aenter__`` /
    ``__aexit__`` support; confirm these overrides still take effect there.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        fallback_enter_value = MagicMock()
        self.aenter_return = kwargs.get('aenter_return', fallback_enter_value)
        self.aexit_return = kwargs.get('aexit_return', None)

    async def __aenter__(self):
        return self.aenter_return

    async def __aexit__(self, exc_type, exc_value, traceback):
        return self.aexit_return

class TestSomething(asynctest.TestCase):
    """Shows AsyncContextManagerMock used in ``async with`` statements."""

    use_default_loop = True

    async def test_async_with(self):
        # Without aenter_return, `as` binds the default MagicMock.
        async with AsyncContextManagerMock() as mock:
            print('first', mock)
        # With aenter_return=5, `as` binds exactly that value.
        async with AsyncContextManagerMock(aenter_return=5) as value:
            print('second', value)

# Build the suite straight from the TestCase class. loadTestsFromModule()
# expects a module; passing a TestSomething instance only worked because the
# loader scans dir() and happened to find the class through __class__.
suite = asynctest.TestLoader().loadTestsFromTestCase(TestSomething)
asynctest.TextTestRunner().run(suite)

.

first 
second 5



----------------------------------------------------------------------
Ran 1 test in 0.007s

OK


