Python Type Hints - How to Vary Return Type Based on an Argument
2021-09-06
Here’s a recipe that combines typing.Literal with @overload to define a function that switches its return type based on the value of an argument.
Imagine we are writing a function to extract an identifier from a particular file format. Most files in this format use unicode identifiers, but in some cases they are raw bytes. For convenience we’d like to default to unicode treatment, but allow callers to retrieve raw bytes when required.
We could start with this definition:
    from __future__ import annotations

    from pathlib import Path


    def extract_acme_id(path: Path, *, unicode: bool = True) -> str | bytes:
        # TODO: implementation
        ...
If unicode is True, we’ll return a str, otherwise we will return bytes.
This initial definition works but is not so usable.
Because return values could be str or bytes, call sites need to use type narrowing before using type-specific operations.
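To make the narrowing burden concrete, here is a minimal sketch of a call site. The function body (reading the file's first line) and the `folded_id` helper are hypothetical stand-ins, since the post leaves the implementation as a TODO.

```python
from __future__ import annotations

from pathlib import Path


def extract_acme_id(path: Path, *, unicode: bool = True) -> str | bytes:
    # Hypothetical body: treat the file's first line as the identifier.
    raw = path.read_bytes().splitlines()[0]
    return raw.decode("utf-8") if unicode else raw


def folded_id(path: Path) -> str:
    acme_id = extract_acme_id(path)
    # The declared type is str | bytes, so a type checker rejects the
    # str-only casefold() until we narrow, even though we asked for unicode.
    if isinstance(acme_id, bytes):
        acme_id = acme_id.decode("utf-8")
    return acme_id.casefold()
```

Every caller that wants a str-only method has to repeat this isinstance() dance, although the bytes branch can never run with the default argument.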
We can fix this by telling the type checker about the correlation between the value of unicode and the return type.
(The eagle-eyed will note that we could also use two separate functions for str and bytes for this simple example, but that’s not always desirable.
We’ll see more interesting examples from the standard library later.)
We can “expand” bool into Literal[True] | Literal[False] and treat those two cases in separate @overload definitions, giving us:
    from __future__ import annotations

    from pathlib import Path
    from typing import Literal, overload


    @overload
    def extract_acme_id(path: Path, *, unicode: Literal[True] = True) -> str: ...


    @overload
    def extract_acme_id(path: Path, *, unicode: Literal[False]) -> bytes: ...


    def extract_acme_id(path: Path, *, unicode: bool = True) -> str | bytes:
        # TODO: implementation
        ...
@overload is long-winded, but it does get the job done.
We can check our types with some calls and reveal_locals():

    path = Path(__name__)
    a = extract_acme_id(path)
    b = extract_acme_id(path, unicode=True)
    c = extract_acme_id(path, unicode=False)
    reveal_locals()
Checking with Mypy:
    $ mypy example.py
    example.py:30: note: Revealed local types are:
    example.py:30: note:     a: builtins.str
    example.py:30: note:     b: builtins.str
    example.py:30: note:     c: builtins.bytes
    example.py:30: note:     path: pathlib.Path*
Each case has the expected return type.
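For a version that also runs, here is a sketch of the two-overload definition with a body filled in. The body (first line of the file) and the helpers `folded_id` and `id_hex` are assumptions made for illustration, not part of the original recipe.

```python
from __future__ import annotations

from pathlib import Path
from typing import Literal, overload


@overload
def extract_acme_id(path: Path, *, unicode: Literal[True] = True) -> str: ...
@overload
def extract_acme_id(path: Path, *, unicode: Literal[False]) -> bytes: ...
def extract_acme_id(path: Path, *, unicode: bool = True) -> str | bytes:
    # Hypothetical body: treat the file's first line as the identifier.
    raw = path.read_bytes().splitlines()[0]
    return raw.decode("utf-8") if unicode else raw


def folded_id(path: Path) -> str:
    # The first overload applies, so str-only methods need no narrowing.
    return extract_acme_id(path).casefold()


def id_hex(path: Path) -> str:
    # The second overload applies, so bytes-only methods need no narrowing.
    return extract_acme_id(path, unicode=False).hex()
```

Only the final, undecorated definition exists at runtime. The @overload stubs are there purely for the type checker.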
With this set of @overload cases, Mypy will not allow callers to pass an arbitrary bool for the unicode argument.
Calls must pass exactly True or False, or rather, any expression with type Literal[True] or Literal[False].
We may want this, as it forces call sites to be predictable.
But if we want the flexibility of arbitrary bool arguments, we need to add one more overload definition:
    @overload
    def extract_acme_id(path: Path, *, unicode: bool) -> str | bytes: ...
This uses the non-specific return type, so callers will need to use type narrowing.
We can check this using a random bool:

    import random

    random_bool = random.random() < 0.5
    d = extract_acme_id(path, unicode=random_bool)
    reveal_type(d)
    $ mypy example.py
    example.py:36: note: Revealed type is "Union[builtins.str, builtins.bytes]"
Mypy shows that d has type str | bytes, in its long-form spelling.
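One place the bool fallback earns its keep is a wrapper that forwards its own flag. The sketch below assumes the same hypothetical body as before, and `describe_id` is an invented helper that narrows the union result itself.

```python
from __future__ import annotations

from pathlib import Path
from typing import Literal, overload


@overload
def extract_acme_id(path: Path, *, unicode: Literal[True] = True) -> str: ...
@overload
def extract_acme_id(path: Path, *, unicode: Literal[False]) -> bytes: ...
@overload
def extract_acme_id(path: Path, *, unicode: bool) -> str | bytes: ...
def extract_acme_id(path: Path, *, unicode: bool = True) -> str | bytes:
    # Hypothetical body: treat the file's first line as the identifier.
    raw = path.read_bytes().splitlines()[0]
    return raw.decode("utf-8") if unicode else raw


def describe_id(path: Path, *, unicode: bool) -> str:
    # `unicode` is a plain bool here, so only the third overload matches
    # and the result is str | bytes.
    acme_id = extract_acme_id(path, unicode=unicode)
    if isinstance(acme_id, bytes):
        return f"raw:{acme_id.hex()}"
    return f"text:{acme_id}"
```

Calls that pass a literal True or False still match the first two overloads, so they keep their precise return types.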
Standard Library Examples
There are several examples of this pattern in the standard library, whose type hints live in the typeshed repository.
In the ast (abstract syntax tree) module, ast.parse() parses Python source into an AST node, represented by an instance of ast.AST.
The “exec” mode of ast.parse() returns an ast.Module, which is a specific subclass of ast.AST.
The definitions in typeshed capture this relationship using Literal:
    @overload
    def parse(
        source: str | bytes,
        filename: str | bytes = ...,
        mode: Literal["exec"] = ...,
        *,
        type_comments: bool = ...,
        feature_version: None | int | tuple[int, int] = ...,
    ) -> Module: ...


    @overload
    def parse(
        source: str | bytes,
        filename: str | bytes = ...,
        mode: str = ...,
        *,
        type_comments: bool = ...,
        feature_version: None | int | tuple[int, int] = ...,
    ) -> AST: ...
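A short sketch of what these two overloads mean at a call site. The comments describe how a checker would read the calls under the overloads quoted above. Newer typeshed releases may be more specific for other modes.

```python
import ast

# mode defaults to "exec", so the first overload applies: ast.Module.
module = ast.parse("x = 1 + 2")
first_statement = module.body[0]  # .body is known to exist on Module

# Any other mode falls through to the second overload: plain ast.AST.
expression = ast.parse("1 + 2", mode="eval")
if isinstance(expression, ast.Expression):
    # Narrowing is needed before touching Expression-specific attributes.
    expression_body = expression.body
```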
subprocess.run() executes a process and returns a CompletedProcess.
When the text argument is True, the output strings stored in the CompletedProcess will be strs, otherwise they will be bytes.
The definitions in typeshed use the Literal pattern, with the CompletedProcess class parametrized based on the string types:
    @overload
    def run(
        # ...
        text: Literal[True],
        # ...
    ) -> CompletedProcess[str]: ...


    @overload
    def run(
        # ...
        text: Literal[None, False] = ...,
        # ...
    ) -> CompletedProcess[bytes]: ...
(Other arguments and @overload variants removed for clarity.)
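As a quick illustration of the two cases, this sketch runs the current Python interpreter as a child process. The command is an arbitrary choice for the example.

```python
import subprocess
import sys

command = [sys.executable, "-c", "print('hello')"]

# text=True matches the first overload: CompletedProcess[str].
as_text = subprocess.run(command, capture_output=True, text=True, check=True)
greeting = as_text.stdout.strip().upper()

# Leaving text unset matches the second overload: CompletedProcess[bytes].
as_bytes = subprocess.run(command, capture_output=True, check=True)
raw = as_bytes.stdout.strip()
```

Because the return type is parametrized, a checker knows the type of stdout and stderr on each result without any narrowing at the call site.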
The open() built-in takes a mode string, whose characters determine the way the file behaves.
The returned file type varies based on the mode, for example using binary mode makes the file read as bytes rather than the default of str.
The mode string alphabet is limited, yielding a manageable number of potential modes. typeshed spells out all the options, grouped within aliases:
    OpenTextModeUpdating = Literal[
        "r+",
        "+r",
        ...,
        "+tx",
    ]
    OpenTextModeWriting = Literal["w", "wt", "tw", "a", "at", "ta", "x", "xt", "tx"]
    OpenTextModeReading = Literal[
        "r", "rt", "tr", "U", "rU", "Ur", "rtU", "rUt", "Urt", "trU", "tUr", "Utr"
    ]
    OpenTextMode = Union[OpenTextModeUpdating, OpenTextModeWriting, OpenTextModeReading]
    OpenBinaryModeUpdating = Literal[
        "rb+",
        ...,
        "+bx",
    ]
    OpenBinaryModeWriting = Literal["wb", "bw", "ab", "ba", "xb", "bx"]
    OpenBinaryModeReading = Literal["rb", "br", "rbU", "rUb", "Urb", "brU", "bUr", "Ubr"]
    OpenBinaryMode = Union[
        OpenBinaryModeUpdating, OpenBinaryModeReading, OpenBinaryModeWriting
    ]
These literals are then used for the definitions of open():

    @overload
    def open(
        file: _OpenFile,
        mode: OpenTextMode = ...,
        # ...
    ) -> TextIOWrapper: ...


    # Unbuffered binary mode: returns a FileIO
    @overload
    def open(
        file: _OpenFile,
        # ...
    ) -> FileIO: ...
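The runtime behaviour these overloads describe can be seen with a small sketch. The temporary file and its contents are invented for the example.

```python
import io
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "example.txt"
    target.write_text("acme", encoding="utf-8")

    # A text mode selects the TextIOWrapper overload; read() gives str.
    with open(target, "r", encoding="utf-8") as text_file:
        text_type = type(text_file)
        text = text_file.read()

    # Binary mode with buffering=0 is the unbuffered case; read() gives bytes.
    with open(target, "rb", buffering=0) as raw_file:
        raw_type = type(raw_file)
        data = raw_file.read()
```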
I literally hope this post has not overloaded you.
Tags: mypy, python