Simple In-Memory Caching of Django Model Data With cachetools

Them’s the United States of America!

A client project was recently suffering from an N+1 queries problem in a complicated Django admin page. Many measures had already been taken to prevent N+1 queries, such as using django-auto-prefetch and some manually tuned select_related() / prefetch_related() calls. The remaining N+1 resisted those methods because it came through several layers of admin code and many-to-many fields, which made it harder than usual to find where to modify the QuerySet construction.

Instead of spelunking through those layers of code to fix this N+1, I took a “shortcut” with some local in-memory caching. This was possible thanks to the nature of the model involved.

The model in question represents a county in the United States:

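import auto_prefetch
from django.db import models


class County(auto_prefetch.Model):
    state = auto_prefetch.ForeignKey(State, on_delete=models.CASCADE)
    name = models.CharField(max_length=30)

    ...

    def __str__(self):
        # Accessing self.state can trigger a query if the State isn't already loaded.
        return f"{self.name}-{self.state.abbreviation}"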

The N queries came from the __str__() method, which is used to display a given County in the admin. A list of N counties was selected from a related model without using select_related() or prefetch_related() to select the related states at the same time. Therefore Django performed an extra query to select the state when calling __str__() on each county. The page in question displayed a list of many counties, so these queries added up to a few seconds of extra processing.
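
To make the arithmetic concrete, here is a stdlib-only simulation, not Django code: the `FakeDB` class, its methods, and the sample data are hypothetical stand-ins for the ORM. Rendering N counties by fetching each state lazily costs 1 + N queries, while a single joined query, roughly what select_related() would produce, costs one.

```python
# Hypothetical in-memory "tables" standing in for the real database.
STATES = {1: "OH", 2: "HI"}
COUNTIES = [("Franklin", 1), ("Cuyahoga", 1), ("Honolulu", 2)]


class FakeDB:
    """A hypothetical stand-in for the ORM that counts issued queries."""

    def __init__(self):
        self.queries = 0

    def fetch_counties(self):
        self.queries += 1
        return list(COUNTIES)

    def fetch_state_abbreviation(self, state_id):
        # One query per call, like accessing county.state without select_related().
        self.queries += 1
        return STATES[state_id]

    def fetch_counties_with_states(self):
        # One joined query, roughly what select_related("state") does.
        self.queries += 1
        return [(name, STATES[state_id]) for name, state_id in COUNTIES]


def render_lazily(db):
    # 1 query for the counties, then 1 more per county: N + 1 in total.
    return [
        f"{name}-{db.fetch_state_abbreviation(state_id)}"
        for name, state_id in db.fetch_counties()
    ]


def render_joined(db):
    # A single query, however many counties there are.
    return [f"{name}-{abbr}" for name, abbr in db.fetch_counties_with_states()]


lazy_db, joined_db = FakeDB(), FakeDB()
assert render_lazily(lazy_db) == render_joined(joined_db)
print(lazy_db.queries, joined_db.queries)  # 4 1
```

With three counties the lazy version issues four queries; with hundreds of counties on one admin page, the per-county queries dominate.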

The problematic data, the abbreviations for all the states along with their IDs, is a notably good candidate for local in-memory caching. It’s small: an integer and a two-character string per state. And it changes very infrequently: the last state admitted was Hawaii, in 1959. This led me to use in-memory caching rather than alternatives such as Django’s cache framework.

To build the cache I used the @ttl_cache (time-to-live cache) decorator from the versatile cachetools library. This wraps a function and caches its results by the given arguments, and expires those cached values after a timeout. If the function takes no arguments, only one value is cached.
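
The semantics can be sketched without the library. The decorator below, `ttl_memoize`, is a hypothetical simplification, not cachetools' implementation: results are keyed by positional arguments, each entry expires `ttl` seconds after it was stored, and an injectable `timer` makes expiry easy to demonstrate. A zero-argument function always uses the empty tuple as its key, so it holds at most one value. The real decorator also provides cache_clear(), which this sketch mirrors.

```python
import functools
import time


def ttl_memoize(ttl=600, timer=time.monotonic):
    """Cache results keyed by positional arguments, expiring after `ttl` seconds.

    A simplified, hypothetical stand-in for cachetools' @ttl_cache: there is no
    maxsize limit and keyword arguments are not supported.
    """

    def decorator(func):
        cache = {}  # args tuple -> (expiry time, value)

        @functools.wraps(func)
        def wrapper(*args):
            now = timer()
            entry = cache.get(args)
            if entry is not None and now < entry[0]:
                return entry[1]  # Fresh cached value.
            value = func(*args)  # Missing or expired: call through.
            cache[args] = (now + ttl, value)
            return value

        wrapper.cache_clear = cache.clear
        return wrapper

    return decorator


# Usage with a fake clock so expiry can be shown deterministically.
clock = [0.0]
calls = []


@ttl_memoize(ttl=600, timer=lambda: clock[0])
def square(n):
    calls.append(n)  # Record each real (uncached) call.
    return n * n


square(3)
square(3)  # Served from the cache.
square(4)  # Different argument, so a separate cache entry.
clock[0] = 600.0  # Advance the fake clock to when the first entries expire.
square(3)  # Expired, so the function runs again.
print(calls)  # [3, 4, 3]
```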

I built a caching function like so:

from cachetools.func import ttl_cache


@ttl_cache(maxsize=1)
def _get_state_abbreviations():
    # Map each State ID to its abbreviation.
    return {
        id_: abbreviation
        for id_, abbreviation in State.objects.values_list("id", "abbreviation")
    }

The first time _get_state_abbreviations() is called, it queries the database and returns a dictionary mapping State IDs to their abbreviations. The query takes only a few milliseconds.

On later calls, until the time-to-live expiry time is reached, the function immediately returns its in-memory cached value. This is almost instant.

On calls after the expiry time, the cached value is discarded, the database is re-queried, and the result is cached again for another time-to-live period.
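
The timeline above can be worked through with a small pure function. `query_times` is a hypothetical helper, not part of cachetools, and it assumes a call landing exactly on the expiry time counts as expired. It reports which calls would go to the database for a zero-argument function cached with a 600-second TTL.

```python
def query_times(call_times, ttl=600):
    """Hypothetical helper: return the call times at which a single-value
    TTL cache re-queries.

    Assumes an entry stored at time t is served until t + ttl, and that a call
    at exactly t + ttl counts as expired.
    """
    queried = []
    expires_at = None  # None means nothing is cached yet.
    for now in call_times:
        if expires_at is None or now >= expires_at:
            queried.append(now)  # Cache miss: hit the database, start a new window.
            expires_at = now + ttl
    return queried


print(query_times([0, 5, 599, 600, 601, 1199, 1200]))  # [0, 600, 1200]
```

Seven calls cost three queries here. Note that each new window starts at the re-query, not on fixed ten-minute boundaries: calls at 0, 700, and 1250 seconds query only at 0 and 700.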

The default TTL for @ttl_cache is 600 seconds, or ten minutes, which is just fine for a query that’s already very fast, on data that changes very rarely. I could then modify the model to use the cached results to form its __str__():

class County(auto_prefetch.Model):
    ...

    def __str__(self):
        return f"{self.name}-{_get_state_abbreviations()[self.state_id]}"
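
The effect on query counts can be sketched with the standard library alone. This is a simulation, not the project's code: `functools.lru_cache(maxsize=1)` stands in for `@ttl_cache(maxsize=1)` (it never expires), and `STATES_TABLE`, `query_count`, and the plain `County` class are hypothetical. However many counties are rendered, the abbreviations are loaded once.

```python
import functools

# Hypothetical stand-ins: a "states table" and a counter of queries against it.
STATES_TABLE = {1: "OH", 2: "HI"}
query_count = 0


@functools.lru_cache(maxsize=1)  # Stand-in for @ttl_cache(maxsize=1), minus expiry.
def get_abbreviations():
    global query_count
    query_count += 1  # Pretend this is State.objects.values_list(...).
    return dict(STATES_TABLE)


class County:
    """A hypothetical plain-Python stand-in for the County model."""

    def __init__(self, name, state_id):
        self.name = name
        self.state_id = state_id

    def __str__(self):
        # Uses only the foreign key value, so no per-county state query.
        return f"{self.name}-{get_abbreviations()[self.state_id]}"


counties = [County("Franklin", 1), County("Cuyahoga", 1), County("Honolulu", 2)]
print([str(c) for c in counties], query_count)
# ['Franklin-OH', 'Cuyahoga-OH', 'Honolulu-HI'] 1
```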

Tests

This worked well and fixed the N+1 queries issue, but it broke tests that hit the County.__str__() method. The cache’s foundational assumption, that the set of States rarely changes, is broken by the test suite. Each test case creates its own test data, so a given State, e.g. Ohio, may be created many times during a test run. Each time the State is created, the database assigns it a new auto-increment ID, which is missing from previously cached data.

The solution I came up with was to wrap cache access in a function that clears the cache and retries if a given state ID is not found:

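def get_state_abbreviation(state_id):
    try:
        return _get_state_abbreviations()[state_id]
    except KeyError:
        pass
    # Unknown ID: the cached data may be stale, so refresh it once.
    # A KeyError from here means the State really doesn't exist.
    _get_state_abbreviations.cache_clear()
    return _get_state_abbreviations()[state_id]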

And then use that function in the __str__() method:

class County(auto_prefetch.Model):
    ...

    def __str__(self):
        return f"{self.name}-{get_state_abbreviation(self.state_id)}"

This also guards against any unexpected situations that might occur on the server, such as the state data getting reloaded with different IDs.
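
The failure mode and the fix can be reproduced without Django. In this stdlib simulation, `functools.lru_cache(maxsize=1)` stands in for `@ttl_cache(maxsize=1)`, and `STATES_TABLE` and `create_state()` are hypothetical stand-ins for the State table, with IDs that are never reused. Recreating a state gives it an ID the cached dictionary has never seen, and the wrapper recovers by clearing and re-reading.

```python
import functools
import itertools

# Hypothetical "states table" whose IDs come from an auto-increment counter
# that, like a database sequence, never hands out the same ID twice.
_ids = itertools.count(1)
STATES_TABLE = {}


def create_state(abbreviation):
    state_id = next(_ids)
    STATES_TABLE[state_id] = abbreviation
    return state_id


@functools.lru_cache(maxsize=1)  # Stand-in for @ttl_cache(maxsize=1).
def _get_state_abbreviations():
    return dict(STATES_TABLE)


def get_state_abbreviation(state_id):
    # Same shape as the wrapper above: on a miss, clear the cache and retry once.
    try:
        return _get_state_abbreviations()[state_id]
    except KeyError:
        pass
    _get_state_abbreviations.cache_clear()
    return _get_state_abbreviations()[state_id]


# "Test 1" creates Ohio and warms the cache.
first_id = create_state("OH")
assert get_state_abbreviation(first_id) == "OH"

# "Test 2" starts from an empty table and creates Ohio again, with a new ID.
STATES_TABLE.clear()
second_id = create_state("OH")
print(first_id, second_id)  # 1 2
print(second_id in _get_state_abbreviations())  # False: the cached dict is stale
print(get_state_abbreviation(second_id))  # OH, after clearing and re-reading
```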

Since this project requires 100% code coverage, I added a test case to exercise all code paths in get_state_abbreviation():

import pytest
from django.test import TestCase

from example.core.models import (
    State,
    _get_state_abbreviations,
    get_state_abbreviation,
)


class GetStateAbbreviationTests(TestCase):
    def setUp(self):
        # The cache outlives each test's database rollback, so start every
        # test with an empty cache to stop stale IDs leaking between tests.
        _get_state_abbreviations.cache_clear()

    def test_get_empty(self):
        with pytest.raises(KeyError):
            get_state_abbreviation(1)

    def test_get_exists(self):
        state = State.objects.create(name="Test State", abbreviation="TS")

        result = get_state_abbreviation(state.id)

        assert result == "TS"

    def test_get_cached(self):
        state = State.objects.create(name="Test State", abbreviation="TS")
        get_state_abbreviation(state.id)

        result = get_state_abbreviation(state.id)

        assert result == "TS"

    def test_get_uncached(self):
        state = State.objects.create(name="Test State", abbreviation="TS")
        get_state_abbreviation(state.id)
        state2 = State.objects.create(name="Test State 2", abbreviation="TT")

        result = get_state_abbreviation(state2.id)

        assert result == "TT"

Great!

Fin

Cache me if you can,

—Adam