Simple In-Memory Caching of Django Model Data With cachetools

A client project was recently suffering from an N+1 queries problem in a complicated Django admin page. Many measures had already been taken to prevent N+1 queries, such as use of django-auto-prefetch and some manually tuned select_related() / prefetch_related() calls. The remaining N+1 in question was resistant to those methods because it came through several layers of admin code and many-to-many fields, making it harder than usual to find the place to modify the QuerySet construction.
Instead of spelunking through those layers of code to fix this N+1, I took a “shortcut” with some local in-memory caching. This was possible because of the particular model.
The model in question represents a county in the United States:
```python
class County(auto_prefetch.Model):
    state = auto_prefetch.ForeignKey(State, on_delete=models.CASCADE)
    name = models.CharField(max_length=30)
    ...

    def __str__(self):
        return f"{self.name}-{self.state.abbreviation}"
```
The N queries came from the __str__() method, which is used to display a given County in the admin. A list of N counties was selected from a related model without using select_related() or prefetch_related() to select the related states at the same time. Therefore Django performed an extra query to select the state when calling __str__() on each county. The page in question displayed a list of many counties, so these queries added up to a few seconds of extra processing.
The problematic data - the abbreviations for all the states, along with the state IDs - is a notably good candidate for local in-memory caching. It’s small - an integer and a two-character string per state. And it changes very infrequently - the last state admitted was Hawaii, in 1959. This led me to think of using in-memory caching, rather than alternatives such as Django’s cache framework.
To build the cache I used the @ttl_cache (time-to-live cache) decorator from the versatile cachetools library. This wraps a function and caches its results by the given arguments, and expires those cached values after a timeout. If the function takes no arguments, only one value is cached.
I built a caching function like so:
```python
from cachetools.func import ttl_cache


@ttl_cache(maxsize=1)
def _get_state_abbreviations():
    return {
        id_: abbreviation
        for id_, abbreviation in State.objects.values_list("id", "abbreviation")
    }
```
The first time _get_state_abbreviations() is called, it queries the database and returns a dictionary mapping State IDs to their abbreviations. The query takes only a few milliseconds.
On later calls, until the time-to-live expiry time is reached, the function immediately returns its in-memory cached value. This is almost instant.
On calls after the expiry time, the cache is ignored, the database re-queried, and the result again cached for the time-to-live period.
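The cache-then-expire behaviour described above can be sketched with a hand-rolled stand-in using only the standard library. This is an illustration of the concept, not the cachetools implementation; simple_ttl_cache and load_data are hypothetical names, and the short TTL is just to make expiry observable:

```python
import time
from functools import wraps


def simple_ttl_cache(ttl=600.0):
    # Cache a zero-argument function's result for `ttl` seconds.
    def decorator(func):
        cached = {}

        @wraps(func)
        def wrapper():
            now = time.monotonic()
            if "value" not in cached or now >= cached["expires"]:
                # First call, or TTL elapsed: recompute and re-cache.
                cached["value"] = func()
                cached["expires"] = now + ttl
            return cached["value"]

        # Expose a cache_clear() hook, as the cachetools decorators do.
        wrapper.cache_clear = cached.clear
        return wrapper

    return decorator


calls = 0


@simple_ttl_cache(ttl=0.05)
def load_data():
    global calls
    calls += 1
    return {"count": calls}


load_data()
load_data()  # within the TTL: served from memory, no recomputation
assert calls == 1

time.sleep(0.06)
load_data()  # past the TTL: recomputed and re-cached
assert calls == 2
```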
The default TTL for @ttl_cache is 600 seconds, or ten minutes, which is just fine for a query that’s already very fast, on data that changes very rarely. I could then modify the model to use the cached results to form its __str__():
```python
class County(auto_prefetch.Model):
    ...

    def __str__(self):
        return f"{self.name}-{_get_state_abbreviations()[self.state_id]}"
```
Tests
This worked well and fixed the N+1 queries issue, but it broke tests hitting the County.__str__() method. The cache’s foundational assumption, that the set of States rarely changes, is broken by the test suite. Each test case creates its own test data, so a given State, e.g. Ohio, may be created many times during a test run. Each time the State is created, the database assigns it a different auto-increment ID, which will be missing from previously cached data.
The solution I came up with here was to wrap cache access with a function that clears the cache if a given state ID is not found:
```python
def get_state_abbreviation(state_id):
    try:
        return _get_state_abbreviations()[state_id]
    except KeyError:
        pass
    _get_state_abbreviations.cache_clear()
    return _get_state_abbreviations()[state_id]
```
And then use that function in the __str__() method:
```python
class County(auto_prefetch.Model):
    ...

    def __str__(self):
        return f"{self.name}-{get_state_abbreviation(self.state_id)}"
```
This also guards against any unexpected situations that might occur on the server, such as the state data getting reloaded with different IDs.
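The clear-on-miss pattern stands on its own, independent of Django. Here’s a minimal sketch with a plain dict standing in for the database table (all names here are illustrative, not from the project):

```python
_db = {1: "OH"}  # stand-in for the State table: {id: abbreviation}
_cache = {}


def _load_abbreviations():
    # "Query" the table and cache the full mapping, like the ttl_cache'd
    # function in the post (expiry omitted for brevity).
    if not _cache:
        _cache.update(_db)
    return _cache


def get_abbreviation(state_id):
    try:
        return _load_abbreviations()[state_id]
    except KeyError:
        pass
    # Unknown ID: assume the cache is stale, drop it, and re-query once.
    _cache.clear()
    return _load_abbreviations()[state_id]


assert get_abbreviation(1) == "OH"
_db[2] = "HI"  # a new row appears after the cache was filled
assert get_abbreviation(2) == "HI"  # miss triggers clear + reload
```

If the ID is genuinely absent from the source data, the second lookup still raises KeyError, matching the behaviour the tests below exercise.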
Since this project requires 100% code coverage, I added a test case to exercise all pathways in get_state_abbreviation():
```python
import pytest
from django.test import TestCase

from example.core.models import State, get_state_abbreviation


class GetStateAbbreviationTests(TestCase):
    def test_get_empty(self):
        with pytest.raises(KeyError):
            get_state_abbreviation(1)

    def test_get_exists(self):
        state = State.objects.create(name="Test State", abbreviation="TS")

        result = get_state_abbreviation(state.id)

        assert result == "TS"

    def test_get_cached(self):
        state = State.objects.create(name="Test State", abbreviation="TS")
        get_state_abbreviation(state.id)

        result = get_state_abbreviation(state.id)

        assert result == "TS"

    def test_get_uncached(self):
        state = State.objects.create(name="Test State", abbreviation="TS")
        get_state_abbreviation(state.id)
        state2 = State.objects.create(name="Test State 2", abbreviation="TT")

        result = get_state_abbreviation(state2.id)

        assert result == "TT"
```
Great!
If your Django project’s long test runs bore you, I wrote a book that can help.