Python Type Hints - How to Type a Context Manager

2021-07-04 We’ll need a whole box of type(ography) tools...

Python’s context manager protocol has only two methods, with straightforward types. But when it comes to adding accurate type hints to a context manager, we still need to combine several typing features. Let’s look at how we can do this for the two different ways of making a context manager.

@contextmanager hints

The easiest way to create a context manager is using the @contextmanager decorator from contextlib. When adding type hints, this remains the easiest way, as we only need to type the underlying generator. For example:

from contextlib import contextmanager
from collections.abc import Generator


@contextmanager
def my_context_manager() -> Generator[None, None, None]:
    yield

Note: using collections.abc.Generator with subscripts is only supported on Python 3.9 and later; on older versions we need to import typing.Generator instead.

We use Generator to specify the three types our generator uses - the yield type, send type, and return type. In this case we do not yield a value, so we set the yield type to None. @contextmanager never sends anything into our generator, and ignores our return value, so we should always use None for the second and third values.
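To see what the three type parameters refer to, here is a small sketch of a plain generator (not a context manager) that uses all of them. The function name running_total and its behaviour are made up for illustration.

```python
from collections.abc import Generator


def running_total() -> Generator[int, int, str]:
    """Yield ints (yield type), receive ints (send type), return a str (return type)."""
    total = 0
    while total < 10:
        # The value passed to send() becomes the result of the yield expression.
        received = yield total
        total += received
    return f"stopped at {total}"


gen = running_total()
first = next(gen)      # runs to the first yield
second = gen.send(4)   # resumes the generator with received = 4
third = gen.send(3)
try:
    gen.send(5)        # pushes the total past 10, so the generator returns
except StopIteration as stop:
    summary = stop.value  # the return value travels on StopIteration
```

A generator wrapped by @contextmanager only ever exercises the first of these, which is why the other two stay as None.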

For context managers that return values, we would swap the first None for the value’s type, for example:

from contextlib import contextmanager
from collections.abc import Generator


@contextmanager
def dice_roll() -> Generator[int, None, None]:
    yield 4
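As a quick usage sketch, the yield type is what a type checker should infer for the variable bound by as. The variable names roll and doubled are only for illustration.

```python
from collections.abc import Generator
from contextlib import contextmanager


@contextmanager
def dice_roll() -> Generator[int, None, None]:
    yield 4


with dice_roll() as roll:
    # A type checker infers roll as int, taken from the generator's yield type.
    doubled = roll * 2
```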

Note: the documentation for typing.Generator notes simple generators can use e.g. -> Iterator[int]. This is currently accepted by the @contextmanager type hints in typeshed, but there’s an open issue showing how this can lead to bugs. Therefore it’s best to stick to Generator. Thanks to Tom Grainger for pointing this out on Twitter, and Anthony Sottile for reporting the issue.
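Here is a sketch of one way an Iterator annotation can hide a problem. This is our own illustration and not necessarily the exact case described in the issue; the function name first_item is hypothetical, and whether a type checker accepts it depends on the typeshed hints still allowing Iterator.

```python
from collections.abc import Iterator
from contextlib import contextmanager


@contextmanager
def first_item() -> Iterator[int]:
    # Not a generator function: it returns a plain list iterator, which has
    # no throw() method. The Iterator[int] annotation is still satisfied.
    return iter([4])


try:
    with first_item() as item:
        raise ValueError("boom")
except AttributeError as exc:
    # contextlib tried to throw the ValueError into the "generator" and failed.
    failure = exc
```

With a Generator[int, None, None] annotation, a type checker would reject the return statement, since a list iterator is not a generator.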

Class-based context managers

We can create more complicated context managers as classes, and this is where the type hints need a bit more work. The simplest definition looks like this:

from __future__ import annotations

from types import TracebackType


class MyContextManager:
    def __enter__(self) -> None:
        pass

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        pass

For context managers that return values, we would swap __enter__’s return type from None to the value’s type.
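For instance, a context manager that returns itself could be sketched like this. The Timer class is a hypothetical example, and it spells the optional types with typing.Optional, which means the same as the | None form, so that the snippet runs without the __future__ import.

```python
import time
from types import TracebackType
from typing import Optional


class Timer:
    """Hypothetical example: measures how long its with block takes."""

    def __init__(self) -> None:
        self._start = 0.0
        self.elapsed = 0.0

    def __enter__(self) -> "Timer":
        self._start = time.perf_counter()
        # Whatever __enter__ returns is bound by "as", so its return type
        # is the type of the with statement's target.
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        self.elapsed = time.perf_counter() - self._start


with Timer() as timer:
    sum(range(1000))
```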

For context managers that suppress exceptions, we would change __exit__’s return type to bool.
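A minimal sketch of a suppressing context manager follows. The IgnoreKeyError class is hypothetical, and again uses typing.Optional in place of | None so it runs without the __future__ import.

```python
from types import TracebackType
from typing import Optional


class IgnoreKeyError:
    """Hypothetical example: swallows KeyError and nothing else."""

    def __enter__(self) -> None:
        pass

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> bool:
        # Returning True tells the with statement to suppress the exception.
        return exc_type is not None and issubclass(exc_type, KeyError)


with IgnoreKeyError():
    {}["missing"]  # raises KeyError, which is suppressed
reached = True

try:
    with IgnoreKeyError():
        raise ValueError("not suppressed")
except ValueError:
    propagated = True
```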

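Our __exit__ method’s type hints say that exc_type is either an exception class or None, exc_val is either an exception instance or None, and exc_tb is either a traceback object or None.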

These are all true, but the type hints don’t represent the correlation between the variables: they’re either all set, or all None. __exit__ can never be called with only some values not None.

If we only care about using this context manager with the with statement, we can handle this correlation as needed with type narrowing inside our __exit__ method’s body. But if we care about users calling __exit__ directly, we can reach for @overload. Let’s look at these two techniques in turn.

Type narrowing in __exit__

To handle the correlation inside the method body, we add some type narrowing with if and assert. If we wanted to handle exceptions, we could use a body like this:

if exc_type is not None:
    ...
else:
    ...

Mypy can use the if statement to infer the type of exc_type in both blocks: it’s type[BaseException] in the if block, and None in the else block. But because Mypy doesn’t know about the correlation between the variables, it can’t narrow the types of exc_val or exc_tb. We can help Mypy narrow the types with assert statements:

if exc_type is not None:
    assert exc_val is not None
    assert exc_tb is not None
    ...
else:
    assert exc_val is None
    assert exc_tb is None
    ...

Mypy can read the asserts and determine the variables’ types in the following lines.
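Put together in a whole class, the narrowing might look like the sketch below. The ErrorRecorder class and its summary attribute are hypothetical, and typing.Optional stands in for | None so the snippet runs without the __future__ import.

```python
from types import TracebackType
from typing import Optional


class ErrorRecorder:
    """Hypothetical example: records a one-line summary of any exception."""

    def __init__(self) -> None:
        self.summary = "no error"

    def __enter__(self) -> "ErrorRecorder":
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        if exc_type is not None:
            assert exc_val is not None
            assert exc_tb is not None
            # Without the assert on exc_tb, the type checker would reject
            # exc_tb.tb_lineno, because exc_tb could still be None.
            self.summary = (
                f"{exc_type.__name__}: {exc_val} (line {exc_tb.tb_lineno})"
            )
        else:
            self.summary = "no error"


try:
    with ErrorRecorder() as recorder:
        raise ValueError("boom")
except ValueError:
    pass  # __exit__ returned None, so the exception propagated to here
```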

The above example contains the complete set of assert statements, but we don’t always need to be so exhaustive. If a block doesn’t use a variable, we don’t need to narrow its type there. For example, many context managers only use the exception value, which we can do without mentioning the other variables:

if exc_val is not None:
    ...  # do something with only exc_val
else:
    ...

If in doubt about the type narrowing you need, debug with reveal_type() or reveal_locals().
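One way to do that without breaking the code at runtime is to put the reveal_type() calls behind TYPE_CHECKING, as in this sketch. The guard is our own suggestion: the type checker still sees the calls, while Python itself never executes them.

```python
from types import TracebackType
from typing import TYPE_CHECKING, Optional


class MyContextManager:
    def __enter__(self) -> None:
        pass

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        if exc_val is not None:
            if TYPE_CHECKING:
                # Only the type checker evaluates this block, so it is safe
                # even where reveal_type() does not exist at runtime.
                reveal_type(exc_val)  # narrowed by the enclosing if
                reveal_type(exc_tb)   # not narrowed, so still optional


with MyContextManager():
    pass
```

Remember to remove the reveal_type() calls once the narrowing looks right.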

Using @overload for __exit__

To make the correlation between __exit__’s arguments visible to callers, we need to use @typing.overload to list the accepted forms. This requires a couple of extra stub functions:

from __future__ import annotations

from typing import overload
from types import TracebackType


class MyContextManager:
    def __enter__(self) -> None:
        pass

    @overload
    def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None:
        ...

    @overload
    def __exit__(
        self,
        exc_type: type[BaseException],
        exc_val: BaseException,
        exc_tb: TracebackType,
    ) -> None:
        ...

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        pass

The first two @overload-decorated functions declare the allowed types for callers. We spell out the two cases: either all the arguments are None, or all the arguments are set.

The final __exit__ function is the implementation, and here we need to combine the overloaded types. Note that, since we have to use unions, inside the body we still need to use type narrowing as above. (Mypy can’t propagate the @overload information into the body at present.)
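Here is a sketch that combines the overloads with narrowing in the implementation body, followed by the two forms of direct call that the overloads accept. The ErrorLog class and its messages attribute are hypothetical, and typing.Optional stands in for | None so the snippet runs without the __future__ import.

```python
from types import TracebackType
from typing import Optional, overload


class ErrorLog:
    """Hypothetical example: collects a message per exception it sees."""

    def __init__(self) -> None:
        self.messages: list[str] = []

    def __enter__(self) -> "ErrorLog":
        return self

    @overload
    def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None:
        ...

    @overload
    def __exit__(
        self,
        exc_type: type[BaseException],
        exc_val: BaseException,
        exc_tb: TracebackType,
    ) -> None:
        ...

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        if exc_type is not None:
            # The overloads do not narrow anything in here, so we still assert.
            assert exc_val is not None
            self.messages.append(f"{exc_type.__name__}{exc_val.args}")


log = ErrorLog()
log.__exit__(None, None, None)  # accepted: matches the all-None overload
try:
    raise ValueError("boom")
except ValueError as exc:
    assert exc.__traceback__ is not None  # narrow the optional traceback
    log.__exit__(type(exc), exc, exc.__traceback__)  # accepted: all set
```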

Now if callers try to pass an incomplete set of arguments, they will get a type error. For example, if we wrote a call like this:

MyContextManager().__exit__(ValueError, None, None)

Then Mypy would complain like so:

$ mypy --strict example.py
example.py:40: error: No overload variant of "__exit__" of "MyContextManager" matches argument types "Type[ValueError]", "None", "None"
example.py:40: note: Possible overload variants:
example.py:40: note:     def __exit__(self, None, None, None) -> None
example.py:40: note:     def __exit__(self, Type[BaseException], BaseException, TracebackType) -> None
Found 1 error in 1 file (checked 1 source file)

Fin

I hope this has managed to give you some context,

—Adam


Tags: mypy, python