Shape typing in Python

While I was looking the other way, Python got advanced static types! Here’s matrix multiplication, with type annotations that describe the shapes of its inputs and its output:

def mat_mul[
    N, K, M
](
    m1: Mat[N, M],
    m2: Mat[M, K],
) -> Mat[N, K]:
    return m1 @ m2

There’s a lot going on here! In traditional Python, we’d write:

def mat_mul(m1, m2):
    return m1 @ m2

Then if we used the wrong shapes, we’d get a runtime error, like this:

>>> m1 @ m2
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: matmul: Input operand 1 has a mismatch
  in its core dimension 0, with gufunc signature
  (n?,k),(k,m?)->(n?,m?) (size 2 is different from 3)
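To reproduce an error like this one, here is a minimal sketch. The 2x3 shapes are an assumption chosen to trigger a mismatch; the session above does not show how m1 and m2 were built.

```python
import numpy as np

# Two 2x3 matrices: the inner dimensions (3 and 2) do not agree,
# so matrix multiplication is undefined for this pair.
m1 = np.ones((2, 3))
m2 = np.ones((2, 3))

try:
    m1 @ m2
    failed = False
except ValueError as err:
    # NumPy only notices the mismatch when this line actually runs.
    failed = True
    print(f"runtime error: {err}")

# Transposing m2 gives a 3x2 matrix, so this product is a valid 2x2 matrix.
product = m1 @ m2.T
print(product.shape)
```

Nothing flags the bad multiplication until the code executes, which is the gap that shape types aim to close.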

Our type-safe wrapper mat_mul uses a type Mat[N, M], which I defined as:

import numpy as np

type Mat[N, M] = np.ndarray[
    tuple[N, M],
    np.dtype[np.float64],  # element dtype; float64 is assumed here
]

If we try to multiply matrices of the wrong shape, Pyright gives a type error.
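As a rough sketch of what that looks like in practice, the snippet below multiplies a 2x3 matrix by a 3x4 matrix. It uses the older TypeVar spelling of the same generics so that it also runs on Python versions before 3.12. The int bound on the type variables, the float64 dtype, and the use of cast to attach shapes are all assumptions of this sketch.

```python
from typing import Literal, TypeVar, cast

import numpy as np

# Older spelling of `def mat_mul[N, K, M]` and `type Mat[N, M]`.
# The `bound=int` and the float64 dtype are assumptions of this sketch.
N = TypeVar("N", bound=int)
K = TypeVar("K", bound=int)
M = TypeVar("M", bound=int)

Mat = np.ndarray[tuple[N, M], np.dtype[np.float64]]


def mat_mul(m1: Mat[N, M], m2: Mat[M, K]) -> Mat[N, K]:
    return m1 @ m2


# cast() only informs the type checker; it does nothing at runtime.
a = cast(Mat[Literal[2], Literal[3]], np.ones((2, 3)))
b = cast(Mat[Literal[3], Literal[4]], np.ones((3, 4)))

ok = mat_mul(a, b)  # shapes line up: (2, 3) @ (3, 4) -> (2, 4)
print(ok.shape)

# The call below mixes incompatible shapes, since M would have to be
# both 3 and 2. It is left commented out because it would also raise
# ValueError at runtime.
# bad = mat_mul(a, a)
```

Whether a checker reports the commented-out call, and how it words the message, may depend on the NumPy and Pyright versions in use.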

The Mat alias is built on NumPy’s np.ndarray type, which takes two type arguments: one describing the shape and one describing the dtype. For example, we can describe a 2x3 matrix of integers as:

from typing import Literal

mat2x3: np.ndarray[
    tuple[Literal[2], Literal[3]],
    np.dtype[np.int64],  # integer elements; int64 is assumed here
] = np.array([[1,2,3],[4,5,6]])

At the moment, most of the NumPy API does not use these type parameters. For example, np.array(...) just gives you an np.ndarray[Any, Any], so we have to write our own type-safe wrappers.
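One possible shape for such a wrapper is sketched below. The helper name mat_2x3 is hypothetical, the float64 dtype is an assumption, and the generics use the older TypeVar spelling so the snippet also runs on Python versions before 3.12. The idea is to check the shape once at runtime and then use cast to hand the type checker a precise static shape.

```python
from typing import Literal, TypeVar, cast

import numpy as np

N = TypeVar("N", bound=int)
M = TypeVar("M", bound=int)

# float64 is an assumed dtype for this sketch.
Mat = np.ndarray[tuple[N, M], np.dtype[np.float64]]


def mat_2x3(rows: list[list[float]]) -> Mat[Literal[2], Literal[3]]:
    """Hypothetical helper: build a 2x3 matrix with a precise static shape."""
    arr = np.array(rows, dtype=np.float64)
    # Check at runtime what the return type promises to the type checker.
    if arr.shape != (2, 3):
        raise ValueError(f"expected shape (2, 3), got {arr.shape}")
    # cast() changes only the static type; the array itself is untouched.
    return cast(Mat[Literal[2], Literal[3]], arr)


m = mat_2x3([[1, 2, 3], [4, 5, 6]])
print(m.shape)
```

Keeping the cast inside one small, checked function means the rest of the code can rely on the annotated shape without repeating unchecked casts.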

