Python Type Hints - How to Vary Return Type Based on an Argument

2021-09-06 So many ropes, so many overloads.

Here’s a recipe that combines typing.Literal with @overload to define a function that switches its return type based on the value of an argument.

ACME IDs

Imagine we are writing a function to extract an identifier from a particular file format. Most files in this format use unicode identifiers, but in some cases they are raw bytes. For convenience we’d like to default to unicode treatment, but allow callers to retrieve raw bytes when required.

We could start with this definition:

from __future__ import annotations

from pathlib import Path


def extract_acme_id(path: Path, *, unicode: bool = True) -> str | bytes:
    # TODO: implementation
    ...

If unicode is True, we’ll return a str; otherwise, we’ll return a bytes.

This initial definition works, but it isn’t very usable. Because the return value could be a str or a bytes, call sites need to use type narrowing before using type-specific operations.
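For example, a caller that wants the ID as text has to check the type first. Here is a minimal sketch of such a call site, using a hypothetical `id_as_text()` helper (not part of the original example); decoding as UTF-8 is also an assumption:

```python
from __future__ import annotations


def id_as_text(acme_id: str | bytes) -> str:
    # Hypothetical call site: with a `str | bytes` value, a type checker
    # rejects str-only operations until the type has been narrowed.
    if isinstance(acme_id, bytes):
        # Narrowed to bytes in this branch; UTF-8 is an assumed encoding.
        return acme_id.decode("utf-8")
    # Narrowed to str here.
    return acme_id
```

Every caller of the str | bytes version needs a check like this, even when it always passes unicode=True.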

We can fix this by telling the type checker about the correlation between the value of unicode and the return type.

(The eagle-eyed will note that for this simple example we could also use two separate functions, one for str and one for bytes, but that’s not always desirable. We’ll see more interesting examples from the standard library later.)

We can “expand” bool into Literal[True] and Literal[False] and treat those two cases in separate @overload definitions, giving us:

from __future__ import annotations

from pathlib import Path
from typing import Literal, overload


@overload
def extract_acme_id(path: Path, *, unicode: Literal[True] = True) -> str:
    ...


@overload
def extract_acme_id(path: Path, *, unicode: Literal[False]) -> bytes:
    ...


def extract_acme_id(path: Path, *, unicode: bool = True) -> str | bytes:
    # TODO: implementation
    ...

Using @overload is long-winded, but it does get the job done.
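The implementation above is left as a TODO. As a minimal sketch, assuming a hypothetical file format where the ID is simply the first line of the file and text IDs are UTF-8 encoded, a runnable version could look like this:

```python
from __future__ import annotations

from pathlib import Path
from typing import Literal, overload


@overload
def extract_acme_id(path: Path, *, unicode: Literal[True] = True) -> str:
    ...


@overload
def extract_acme_id(path: Path, *, unicode: Literal[False]) -> bytes:
    ...


def extract_acme_id(path: Path, *, unicode: bool = True) -> str | bytes:
    # Assumed format: the ID is everything before the first newline.
    raw = path.read_bytes().split(b"\n", 1)[0]
    if unicode:
        return raw.decode("utf-8")  # assumed encoding
    return raw
```

Only the final, undecorated definition runs: the name extract_acme_id ends up bound to it, while the @overload stubs exist purely for the type checker.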

We can check our types with some calls and reveal_locals():

path = Path(__name__)

a = extract_acme_id(path)
b = extract_acme_id(path, unicode=True)
c = extract_acme_id(path, unicode=False)
reveal_locals()

Checking with Mypy:

$ mypy example.py
example.py:30: note: Revealed local types are:
example.py:30: note:     a: builtins.str
example.py:30: note:     b: builtins.str
example.py:30: note:     c: builtins.bytes
example.py:30: note:     path: pathlib.Path*

Each case has the expected return type.

With this set of @overload cases, Mypy will not allow callers to pass an arbitrary bool for the unicode argument. Calls must pass exactly True or False, or rather, any expression with type Literal[True] or Literal[False].

We may want this, as it forces call sites to be predictable. But if we want the flexibility of arbitrary bool arguments, we need to add one more overload definition:

@overload
def extract_acme_id(path: Path, *, unicode: bool) -> str | bytes:
    ...

This uses the non-specific return type, so callers will need to use type narrowing.

We can check this using a random bool value:

import random

random_bool = random.random() < 0.5
d = extract_acme_id(path, unicode=random_bool)
reveal_type(d)

Running Mypy:

$ mypy example.py
example.py:36: note: Revealed type is "Union[builtins.str, builtins.bytes]"

Mypy shows that d has type str | bytes, in its long-form spelling.

Success!
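Putting the three overloads together, here is a runnable sketch. It reuses the same hypothetical first-line file format as an assumption, and the call-site functions casefolded_id() and raw_id() are illustrative names, not from the original example. raw_id() shows that a variable annotated Literal[False] selects the bytes overload, while casefolded_id() narrows the str | bytes result of an arbitrary bool:

```python
from __future__ import annotations

import random
from pathlib import Path
from typing import Literal, overload


@overload
def extract_acme_id(path: Path, *, unicode: Literal[True] = True) -> str:
    ...


@overload
def extract_acme_id(path: Path, *, unicode: Literal[False]) -> bytes:
    ...


@overload
def extract_acme_id(path: Path, *, unicode: bool) -> str | bytes:
    ...


def extract_acme_id(path: Path, *, unicode: bool = True) -> str | bytes:
    # Assumed format: the ID is the file's first line, UTF-8 for text.
    raw = path.read_bytes().split(b"\n", 1)[0]
    return raw.decode("utf-8") if unicode else raw


def raw_id(path: Path) -> bytes:
    # Hypothetical call site: an expression typed Literal[False]
    # matches the second overload, so the result is bytes.
    binary_only: Literal[False] = False
    return extract_acme_id(path, unicode=binary_only)


def casefolded_id(path: Path) -> str:
    # Hypothetical call site: an arbitrary bool matches the third overload,
    # so acme_id is typed `str | bytes` and must be narrowed before calling
    # a str-only method such as casefold().
    random_bool = random.random() < 0.5
    acme_id = extract_acme_id(path, unicode=random_bool)
    if isinstance(acme_id, bytes):
        acme_id = acme_id.decode("utf-8")
    return acme_id.casefold()
```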

Standard Library Examples

There are several examples of this pattern in the standard library, whose type hints live in the typeshed repository.

ast.parse()

In the ast (abstract syntax tree) module, ast.parse() parses Python source into an ast node, represented by an instance of ast.AST. The “exec” mode of ast.parse() returns an ast.Module, which is a specific subclass of ast.AST.

The definitions in typeshed capture this relationship using @overload with Literal (Source):

@overload
def parse(
    source: str | bytes,
    filename: str | bytes = ...,
    mode: Literal["exec"] = ...,
    *,
    type_comments: bool = ...,
    feature_version: None | int | tuple[int, int] = ...,
) -> Module:
    ...


@overload
def parse(
    source: str | bytes,
    filename: str | bytes = ...,
    mode: str = ...,
    *,
    type_comments: bool = ...,
    feature_version: None | int | tuple[int, int] = ...,
) -> AST:
    ...
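A quick runtime check (not a type checker run) shows the concrete classes these calls produce. With only the two overloads shown above, a type checker would see the "eval" call as returning plain AST, even though at runtime it produces an ast.Expression; newer typeshed versions may add more specific overloads for other modes.

```python
import ast

# Default mode is "exec", matching the first overload: an ast.Module.
module = ast.parse("x = 1")

# Any other mode string falls through to the second overload (typed AST);
# at runtime, "eval" mode returns an ast.Expression, a subclass of ast.AST.
expression = ast.parse("1 + 2", mode="eval")
```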

subprocess.run()

In the subprocess module, subprocess.run() executes a process and returns a CompletedProcess instance. If the text argument is True, the output strings stored in the CompletedProcess will be strs, otherwise they will be byteses.

The definitions in typeshed use the @overload + Literal pattern, with the CompletedProcess class parametrized based on the string types (Source):

@overload
def run(
    # ...
    text: Literal[True],
    # ...
) -> CompletedProcess[str]:
    ...


@overload
def run(
    # ...
    text: Literal[None, False] = ...,
    # ...
) -> CompletedProcess[bytes]:
    ...

(Other arguments and @overload variants removed for clarity.)
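A small runtime sketch of both cases follows. It uses the current interpreter (sys.executable) as the child process, an arbitrary choice so the example doesn't depend on any particular shell command being installed:

```python
import subprocess
import sys

# A tiny child process that prints one line.
cmd = [sys.executable, "-c", "print('hello')"]

# text=True matches the Literal[True] overload: stdout is a str.
as_text = subprocess.run(cmd, capture_output=True, text=True, check=True)

# Omitting text (it defaults to None) matches the other overload:
# stdout is bytes.
as_bytes = subprocess.run(cmd, capture_output=True, check=True)
```

stdout includes the trailing newline (\r\n on Windows), so strip it before comparing.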

open()

The open() built-in takes a mode string, whose characters determine the way the file behaves. The returned file type varies based on the mode. For example, binary mode makes reads return bytes rather than the default of str.

The mode string alphabet is limited, yielding a manageable number of potential modes. typeshed spells out all the options, grouped within aliases (Source):

OpenTextModeUpdating = Literal[
    "r+",
    "+r",
    ...,
    "+tx",
]
OpenTextModeWriting = Literal["w", "wt", "tw", "a", "at", "ta", "x", "xt", "tx"]
OpenTextModeReading = Literal[
    "r", "rt", "tr", "U", "rU", "Ur", "rtU", "rUt", "Urt", "trU", "tUr", "Utr"
]
OpenTextMode = Union[OpenTextModeUpdating, OpenTextModeWriting, OpenTextModeReading]
OpenBinaryModeUpdating = Literal[
    "rb+",
    ...,
    "+bx",
]
OpenBinaryModeWriting = Literal["wb", "bw", "ab", "ba", "xb", "bx"]
OpenBinaryModeReading = Literal["rb", "br", "rbU", "rUb", "Urb", "brU", "bUr", "Ubr"]
OpenBinaryMode = Union[
    OpenBinaryModeUpdating, OpenBinaryModeReading, OpenBinaryModeWriting
]

These literals are then used for the definitions of open():

@overload
def open(
    file: _OpenFile,
    mode: OpenTextMode = ...,
    # ...
) -> TextIOWrapper:
    ...


# Unbuffered binary mode: returns a FileIO
@overload
def open(
    file: _OpenFile,
    # ...
) -> FileIO:
    ...
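At runtime, the different modes really do produce different classes. This sketch opens a temporary file three ways. The buffered "rb" case is covered by a typeshed overload not shown in the excerpt above, and the buffering=0 case corresponds to the unbuffered FileIO overload:

```python
import io
import os
import tempfile

# Create a temporary file so the example has something to open.
fd, path = tempfile.mkstemp()
os.close(fd)

# A text mode such as "w" matches OpenTextMode: a TextIOWrapper.
with open(path, "w") as f:
    f.write("hello")
    text_type = type(f)

# A buffered binary read returns a BufferedReader...
with open(path, "rb") as f:
    buffered_type = type(f)

# ...while unbuffered binary mode (buffering=0) returns a raw FileIO.
with open(path, "rb", buffering=0) as f:
    unbuffered_type = type(f)

os.remove(path)
```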

Fin

I literally hope this post has not overloaded you,

—Adam



Tags: mypy, python