Python’s unittest.mock.patch allows you to temporarily replace real objects in your test suite. The API, however, is confusing. I’ll go over some of the basic ways to use patch.
First, here’s the set of files that I’ll be referencing throughout this post:
product.py
import membership

def discounted_price_for_user(user_id, original_price=10):
    return original_price - membership.user_discount(user_id)
membership.py
def user_discount(user_id):
    raise RuntimeError("unwanted sideeffect")
discounted_price_for_user inside the product module is the System Under Test (SUT), and it collaborates with (depends on) user_discount from the membership module. I’ve intentionally raised an exception inside user_discount to simulate an unwanted side effect. It stands for anything you don’t want to actually run in a unit test, and it will be the target of our patching.
Approaches
There are two common ways to use the patch API, and both serve the same purpose:
- Decorator
- Context Manager
Both do the same core job: replace the original object with a mock at the start of the test, then restore the original object at the end of it.
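A patcher returned by patch also exposes start() and stop() methods, which perform that setup and restore by hand and make the two steps easy to see in isolation. A minimal sketch, assuming an in-memory stand-in for membership.py (built with types.ModuleType and registered in sys.modules so patch can import it by name) instead of a real file:

```python
import sys
import types
from unittest.mock import patch

# Hypothetical in-memory stand-in for the post's membership.py, registered in
# sys.modules so that patch("membership.user_discount") can import it by name.
membership = types.ModuleType("membership")


def user_discount(user_id):
    raise RuntimeError("unwanted sideeffect")


membership.user_discount = user_discount
sys.modules["membership"] = membership

original = membership.user_discount

# Extra keyword arguments to patch() are passed on to the MagicMock it creates.
patcher = patch("membership.user_discount", return_value=5)
mock_user_discount = patcher.start()  # setup: swap the mock in
swapped_in = membership.user_discount is mock_user_discount
result_during = membership.user_discount(1)  # 5, no RuntimeError
patcher.stop()  # teardown: put the original function back
restored = membership.user_discount is original

print(swapped_in, result_during, restored)  # True 5 True
```

The decorator and the context manager do the equivalent of start() on entry and stop() on exit, so you rarely need to call them yourself.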
Decorator
from unittest.mock import patch
import unittest
import product

class TestProduct(unittest.TestCase):
    @patch("membership.user_discount")
    def test_product_price(self, mock_user_discount):
        discount_amt = 5
        mock_user_discount.return_value = discount_amt
        self.assertEqual(product.discounted_price_for_user(1, original_price=10), 5)
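patch accepts a path to the object where it is looked up. Since product looks up user_discount as an attribute of the membership module (membership.user_discount(user_id)), the full path is membership.user_discount. patch imports the module named in that string and temporarily replaces the named attribute on it, so the SUT finds the mock instead of the real function.
If it can’t find the target, patch raises an error when patching starts instead of silently doing nothing: typically ModuleNotFoundError if the module can’t be imported, or AttributeError if the module has no attribute with that name.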
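A quick way to see which error a bad target produces. The helper patch_error is hypothetical, and json is used only because it is always importable:

```python
from unittest.mock import patch


def patch_error(target):
    """Hypothetical helper: enter patch(target) and report the exception raised, if any."""
    try:
        # The target is resolved when patching starts, not when patch() is called.
        with patch(target):
            pass
    except (AttributeError, ImportError) as exc:
        return type(exc).__name__
    return None


print(patch_error("json.dumps"))            # None: module and attribute both exist
print(patch_error("json.no_such_name"))     # AttributeError: module found, name missing
print(patch_error("no_such_module_xyz.f"))  # ModuleNotFoundError: module can't be imported
```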
Instead of a string, you can use patch.object and supply the object that contains the name you’re replacing, plus the name itself:
from unittest.mock import patch
import unittest
import membership
import product

class TestProduct(unittest.TestCase):
    @patch.object(membership, "user_discount")
    def test_product_price(self, mock_user_discount):
        mock_user_discount.return_value = 5
        self.assertEqual(product.discounted_price_for_user(1, original_price=10), 5)
In this case, membership must be imported into the test module so that it can be referenced directly.
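So far, patch creates the mock object and passes it into the test as an extra argument. This can get hard to read, especially with several patches stacked. One way around it is to pass the replacement object explicitly as the second argument to patch; when you do, patch doesn’t inject an extra argument into the test.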
from unittest.mock import patch, Mock
import unittest
import product

class TestProduct(unittest.TestCase):
    @patch("membership.user_discount", Mock(return_value=5))
    def test_product_price(self):
        self.assertEqual(product.discounted_price_for_user(1, original_price=10), 5)
#1 Patch Gotcha
What happens when you change product.py to look like this:
from membership import user_discount

def discounted_price_for_user(user_id, original_price=10):
    return original_price - user_discount(user_id)
Instead of importing membership into the product namespace and looking up user_discount in membership, we’re importing user_discount directly into the product namespace.
Now if I run the test, boom:
ERROR: test_product_price (__main__.TestProduct)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/alin/.pyenv/pyenv/versions/3.8.3/lib/python3.8/unittest/mock.py", line 1325, in patched
    return func(*newargs, **newkeywargs)
  File "test.py", line 10, in test_product_price
    self.assertEqual(product.discounted_price_for_user(1, original_price=10), 5)
  File "/home/alin/code/sandbox/patch-examples/product.py", line 4, in discounted_price_for_user
    return original_price - user_discount(user_id)
  File "/home/alin/code/sandbox/patch-examples/membership.py", line 2, in user_discount
    raise RuntimeError("unwanted sideeffect")
RuntimeError: unwanted sideeffect
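Uh oh. What happened? The statement from membership import user_discount runs when product is imported and binds a second name, product.user_discount, to the original function. patch("membership.user_discount") only rebinds the name inside membership, so product keeps calling the original. The "Where to patch" section of the unittest.mock documentation explains this in more detail; if the error still doesn’t make sense, go read that section.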
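The underlying behavior is plain Python name binding, and it shows up without patch at all. A minimal sketch, assuming a hypothetical make_module helper that builds the two modules in memory (instead of as files) and registers them in sys.modules so import statements find them:

```python
import sys
import types


def make_module(name, source):
    """Hypothetical helper: build a module from source text and register it for import."""
    module = types.ModuleType(name)
    sys.modules[name] = module
    exec(source, module.__dict__)
    return module


# Stand-ins for membership.py and the changed product.py; the discount is a harmless 0.
membership = make_module("membership", "def user_discount(user_id):\n    return 0\n")
product = make_module("product", "from membership import user_discount\n")

# After the import, both names refer to the same function object...
print(product.user_discount is membership.user_discount)  # True

# ...but rebinding the name inside membership, which is what
# patch("membership.user_discount") does, leaves product's own name untouched.
membership.user_discount = lambda user_id: 5
print(product.user_discount is membership.user_discount)  # False
print(product.user_discount(1))  # 0: product still calls the original function
```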
Now let’s fix this:
class TestProduct(unittest.TestCase):
    @patch("product.user_discount")
    def test_product_price(self, mock_user_discount):
        mock_user_discount.return_value = 5
        self.assertEqual(product.discounted_price_for_user(1, original_price=10), 5)
Since user_discount is now looked up inside the product module (even though it’s defined in membership), we need to give patch the path where it’s looked up: product.user_discount. Unfortunately, this is a fundamental behavior of patch that will cost you hours of headaches if you don’t internalize it.
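A self-contained sketch that runs the changed product module under both patch targets, assuming a hypothetical make_module helper that builds membership and product in memory and registers them in sys.modules, in place of the real files:

```python
import sys
import types
from unittest.mock import patch


def make_module(name, source):
    """Hypothetical helper: build a module from source text and register it for import."""
    module = types.ModuleType(name)
    sys.modules[name] = module
    exec(source, module.__dict__)
    return module


make_module("membership", '''
def user_discount(user_id):
    raise RuntimeError("unwanted sideeffect")
''')

product = make_module("product", '''
from membership import user_discount

def discounted_price_for_user(user_id, original_price=10):
    return original_price - user_discount(user_id)
''')


def price_with_patch(target):
    """Patch `target` to return 5, call the SUT, and report the price or the error."""
    with patch(target, return_value=5):
        try:
            return product.discounted_price_for_user(1, original_price=10)
        except RuntimeError as exc:
            return f"RuntimeError: {exc}"


print(price_with_patch("membership.user_discount"))  # RuntimeError: unwanted sideeffect
print(price_with_patch("product.user_discount"))     # 5
```

Only the second target replaces the name the SUT actually calls.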
Context Manager
Nearly everything I covered using the decorator approach also applies to context managers. Here’s how to use patch as a context manager:
from unittest.mock import patch, Mock
import unittest
import product
import membership

class TestProduct(unittest.TestCase):
    def test_product_price(self):
        with patch("product.user_discount") as mock_user_discount:
            mock_user_discount.return_value = 5
            self.assertEqual(product.discounted_price_for_user(1, original_price=10), 5)
Just like with decorators, you can use patch.object here as well, or supply a Mock with a return value up front. I suggest you play around with it.
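A self-contained sketch of both variants inside a with block, assuming a hypothetical make_module helper that builds membership.py and the changed product.py (the version that imports user_discount directly) in memory:

```python
import sys
import types
import unittest
from unittest.mock import Mock, patch


def make_module(name, source):
    """Hypothetical helper: build a module from source text and register it for import."""
    module = types.ModuleType(name)
    sys.modules[name] = module
    exec(source, module.__dict__)
    return module


# In-memory stand-ins for membership.py and the changed product.py.
make_module("membership", '''
def user_discount(user_id):
    raise RuntimeError("unwanted sideeffect")
''')

product = make_module("product", '''
from membership import user_discount

def discounted_price_for_user(user_id, original_price=10):
    return original_price - user_discount(user_id)
''')


class TestProduct(unittest.TestCase):
    def test_patch_object(self):
        # patch.object takes the module object itself plus the attribute name.
        with patch.object(product, "user_discount") as mock_user_discount:
            mock_user_discount.return_value = 5
            self.assertEqual(product.discounted_price_for_user(1, original_price=10), 5)

    def test_mock_supplied_up_front(self):
        # With an explicit replacement, `as` binds that same Mock object.
        with patch("product.user_discount", Mock(return_value=5)) as mock_user_discount:
            self.assertEqual(product.discounted_price_for_user(1, original_price=10), 5)
            mock_user_discount.assert_called_once_with(1)


# Run the tests without unittest.main(), which would call sys.exit().
result = unittest.TextTestRunner(verbosity=2).run(
    unittest.defaultTestLoader.loadTestsFromTestCase(TestProduct)
)
```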
One apparent drawback of the context manager approach shows up when you patch multiple things. Pretend for a moment that you have a function holiday_discount, defined in the same module as user_discount and imported into product the same way. Just like user_discount, you want to mock it out because it carries unwanted side effects. Assume discounted_price_for_user subtracts both discounts, so with each mocked to return 5, a price of 10 drops to 0.
To patch both functions, you might end up with:
class TestProduct(unittest.TestCase):
    def test_product_price(self):
        with patch("product.user_discount", Mock(return_value=5)):
            with patch("product.holiday_discount", Mock(return_value=5)):
                self.assertEqual(product.discounted_price_for_user(1, original_price=10), 0)
Welcome to nesting hell. There are two easy ways out of this:
Comma-separate your patch calls:
class TestProduct(unittest.TestCase):
    def test_product_price(self):
        with patch("product.user_discount", Mock(return_value=5)), patch("product.holiday_discount", Mock(return_value=5)):
            self.assertEqual(product.discounted_price_for_user(1, original_price=10), 0)
Or use patch.multiple (you can also use this with the decorator approach):
class TestProduct(unittest.TestCase):
    def test_product_price(self):
        with patch.multiple("product", user_discount=Mock(return_value=5), holiday_discount=Mock(return_value=5)):
            self.assertEqual(product.discounted_price_for_user(1, original_price=10), 0)
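A sketch of the decorator form. When patch.multiple is used as a decorator and a value is set to DEFAULT, patch creates a MagicMock for it and passes it into the test as a keyword argument named after the attribute. The holiday_discount function, its signature, and the two-discount version of product are hypothetical, built in memory with a hypothetical make_module helper:

```python
import sys
import types
import unittest
from unittest.mock import DEFAULT, patch


def make_module(name, source):
    """Hypothetical helper: build a module from source text and register it for import."""
    module = types.ModuleType(name)
    sys.modules[name] = module
    exec(source, module.__dict__)
    return module


# Hypothetical modules: holiday_discount and its signature are assumptions.
make_module("membership", '''
def user_discount(user_id):
    raise RuntimeError("unwanted sideeffect")

def holiday_discount(user_id):
    raise RuntimeError("unwanted sideeffect")
''')

product = make_module("product", '''
from membership import user_discount, holiday_discount

def discounted_price_for_user(user_id, original_price=10):
    return original_price - user_discount(user_id) - holiday_discount(user_id)
''')


class TestProduct(unittest.TestCase):
    # DEFAULT asks patch.multiple to create a MagicMock for each name; as a
    # decorator, it passes those mocks to the test as keyword arguments.
    @patch.multiple("product", user_discount=DEFAULT, holiday_discount=DEFAULT)
    def test_product_price(self, user_discount, holiday_discount):
        user_discount.return_value = 5
        holiday_discount.return_value = 5
        self.assertEqual(product.discounted_price_for_user(1, original_price=10), 0)


# Run the test without unittest.main(), which would call sys.exit().
result = unittest.TextTestRunner(verbosity=2).run(
    unittest.defaultTestLoader.loadTestsFromTestCase(TestProduct)
)
```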
Overall, I strongly prefer context managers because I like reading statements from top to bottom and seeing mocks close to where they’re used. The decorator approach, however, forces you to look at the details of the test before you’ve even read the name of the test.