Simple In-Memory Caching of Django Model Data With cachetools

2021-01-21

A client project recently was suffering from an N+1 queries problem in a complicated Django admin page. Many measures had already been taken to prevent N+1 queries, such as use of django-auto-prefetch and some manually tuned select_related() / prefetch_related() calls. The remaining N+1 in question was a bit resistant to those methods because it came through several layers of admin code and many-to-many fields, making it harder than normal to find the place to modify the QuerySet construction.

Instead of spelunking through those layers of code to fix this N+1, I took a “shortcut” with some local in-memory caching. This was possible because of the particular model.

The model in question represents a county in the United States:

class County(auto_prefetch.Model):
    state = auto_prefetch.ForeignKey(State, on_delete=models.CASCADE)
    name = models.CharField(max_length=30)

    ...

    def __str__(self):
        return f"{self.name}-{self.state.abbreviation}"

The N queries came from the __str__() method, which is used to display a given County in the admin. A list of N counties was selected from a related model without using select_related() or prefetch_related() to select the related states at the same time. Therefore Django performed an extra query to select the state when calling __str__() on each county. The page in question displayed a list of many counties, so these queries added up to a few seconds of extra processing.

The problematic data - the abbreviations for all the states, along with the state IDs - is a notably good candidate for local in-memory caching. It’s small - an integer and a two-character string per state. And it changes very infrequently - the last state admitted was Hawaii, in 1959. This led me to think of using in-memory caching, rather than alternatives such as Django’s cache framework.

To build the cache I used the @ttl_cache (time-to-live cache) decorator from the versatile cachetools library. This wraps a function, caching its results keyed by the arguments it was called with, and expires cached values after a timeout. If the function takes no arguments, only a single value is cached.
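To make the mechanics concrete, here’s a minimal standard-library sketch of what a TTL cache does for a zero-argument function. This is only an illustration of the idea, not cachetools’ actual implementation - the name simple_ttl_cache and its internals are mine:

```python
import time
from functools import wraps


def simple_ttl_cache(ttl=600):
    """Cache the result of a zero-argument function for `ttl` seconds."""

    def decorator(func):
        state = {"value": None, "expires": 0.0}

        @wraps(func)
        def wrapper():
            now = time.monotonic()
            if now >= state["expires"]:
                # Cache is cold or expired: recompute and reset the clock.
                state["value"] = func()
                state["expires"] = now + ttl
            return state["value"]

        def cache_clear():
            # Force the next call to recompute, mirroring cachetools'
            # cache_clear() method.
            state["expires"] = 0.0

        wrapper.cache_clear = cache_clear
        return wrapper

    return decorator
```

The first call computes and stores the result; every later call within the TTL returns the stored value without running the function body again.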

I built a caching function like so:

from cachetools.func import ttl_cache


@ttl_cache(maxsize=1)
def _get_state_abbreviations():
    return {
        id_: abbreviation
        for id_, abbreviation in State.objects.values_list("id", "abbreviation")
    }

The first time _get_state_abbreviations() is called, it queries the database and returns a dictionary mapping State IDs to their abbreviations. The query takes only a few milliseconds.

On later calls, until the time-to-live expiry time is reached, the function immediately returns its in-memory cached value. This is almost instant.

On calls after the expiry time, the cache is ignored, the database re-queried, and the result again cached for the time-to-live period.

The default TTL for @ttl_cache is 600 seconds, or ten minutes, which is just fine for a query that’s already very fast, on data that changes very rarely. I could then modify the model to use the cached results to form its __str__():

class County(auto_prefetch.Model):
    ...

    def __str__(self):
        return f"{self.name}-{_get_state_abbreviations()[self.state_id]}"

Tests

This worked well and fixed the N+1 queries issue, but it broke tests hitting the County.__str__() method. The cache’s foundational assumption, that the set of States rarely changes, is broken by the test suite. Each test case creates its own test data, so a given State, e.g. Ohio, may be created many times during a test run. Each time the State is created, the database assigns it a different auto-increment ID, which will be missing from previously cached data.

The solution I came up with here was to wrap cache access with a function that clears the cache if a given state ID is not found:

def get_state_abbreviation(state_id):
    try:
        return _get_state_abbreviations()[state_id]
    except KeyError:
        pass
    _get_state_abbreviations.cache_clear()
    return _get_state_abbreviations()[state_id]

And then use that function in the __str__() method:

class County(auto_prefetch.Model):
    ...

    def __str__(self):
        return f"{self.name}-{get_state_abbreviation(self.state_id)}"

This also guards against any unexpected situations that might occur on the server, such as the state data getting reloaded with different IDs.

Since this project requires 100% code coverage I added a test case to exercise all pathways in get_state_abbreviation():

import pytest
from django.test import TestCase

from example.core.models import State, get_state_abbreviation


class GetStateAbbreviationTests(TestCase):
    def test_get_empty(self):
        with pytest.raises(KeyError):
            get_state_abbreviation(1)

    def test_get_exists(self):
        state = State.objects.create(name="Test State", abbreviation="TS")

        result = get_state_abbreviation(state.id)

        assert result == "TS"

    def test_get_cached(self):
        state = State.objects.create(name="Test State", abbreviation="TS")
        get_state_abbreviation(state.id)

        result = get_state_abbreviation(state.id)

        assert result == "TS"

    def test_get_uncached(self):
        state = State.objects.create(name="Test State", abbreviation="TS")
        get_state_abbreviation(state.id)
        state2 = State.objects.create(name="Test State 2", abbreviation="TT")

        result = get_state_abbreviation(state2.id)

        assert result == "TT"

Great!
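Another option worth noting (not what this project does) is to clear the cache between test cases so each one starts cold. A standard-library sketch, with functools.lru_cache standing in for @ttl_cache since both expose a cache_clear() method:

```python
import unittest
from functools import lru_cache

calls = []


@lru_cache(maxsize=1)  # stand-in for @ttl_cache; both expose cache_clear()
def load_data():
    calls.append(1)
    return len(calls)


class ColdCacheTests(unittest.TestCase):
    def setUp(self):
        # Each test starts with an empty cache, so stale IDs can't leak
        # between test cases.
        load_data.cache_clear()

    def test_loads(self):
        before = len(calls)
        load_data()
        assert len(calls) == before + 1

    def test_loads_again(self):
        before = len(calls)
        load_data()
        assert len(calls) == before + 1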

Fin

Cache me if you can,

—Adam
