Module pggm_datalab_utils.flatten

Expand source code
from typing import Iterable, Any
from itertools import accumulate
import operator as op


def flatten(it: Iterable[Iterable[Any]]) -> list[Any]:
    """
    Flatten a list of lists, or any kind of nested iterable, into a single list.

    Exactly one level of nesting is removed: every element of every item in `it` is appended, in order, to the result.
    Deeper structures inside those elements are left untouched.

    :param it: iterable whose items are themselves iterable.
    :return: a new list containing the concatenation of all items.
    """
    combined: list[Any] = []
    for sub in it:
        # `extend` consumes any iterable, so generators and other one-shot iterators work as inner items.
        combined.extend(sub)
    return combined


def flatten_with_mapping(it: Iterable[Iterable[Any]]) -> tuple[list[Any], list[int]]:
    """
    Flatten a list of lists, or any kind of nested iterable, into a single list, additionally returning the lengths
    of each of the items, allowing you to restore the original list of lists using `unflatten`.
    We test that `unflatten(*flatten_with_mapping(x)) == x` and vice versa.

    The input is traversed exactly once. One-shot iterators (e.g. generators) are therefore supported, both as the
    outer iterable and as the inner items.

    :param it: iterable whose items are themselves iterable.
    :return: tuple ``(flattened, mapping)``. ``flattened`` is the concatenation of all items, and ``mapping[i]`` is
        the number of elements contributed by the i-th item.
    """
    flattened: list[Any] = []
    mapping: list[int] = []
    for item in it:
        # Measure how much each item added, rather than calling `len(item)`. A second pass over `it` would find a
        # generator exhausted, and `len` is undefined for inner iterators.
        before = len(flattened)
        flattened.extend(item)
        mapping.append(len(flattened) - before)
    return flattened, mapping


def unflatten(flattened: list[Any], mapping: list[int]) -> list[list[Any]]:
    """
    Unflatten a list, breaking it into sublists of length indicated by `mapping`.
    We test that `unflatten(*flatten_with_mapping(x)) == x` and vice versa.

    Each sublist is a slice of `flattened`. If the lengths in `mapping` sum to less than `len(flattened)`, the trailing
    elements are not included in any sublist.

    :param flattened: the flat sequence to split.
    :param mapping: the length of each consecutive sublist.
    :return: one slice of `flattened` per entry in `mapping`.
    """
    pieces: list[list[Any]] = []
    start = 0
    # Running totals of `mapping` give the end offset of each consecutive slice.
    for end in accumulate(mapping, op.add):
        pieces.append(flattened[start:end])
        start = end
    return pieces

Functions

def flatten(it: Iterable[Iterable[Any]]) -> list[typing.Any]

Flatten a list of lists, or any kind of nested iterable, into a single list.

Expand source code
def flatten(it: Iterable[Iterable[Any]]) -> list[Any]:
    """
    Flatten a list of lists, or any kind of nested iterable, into a single list.

    Exactly one level of nesting is removed: every element of every item in `it` is appended, in order, to the result.
    Deeper structures inside those elements are left untouched.

    :param it: iterable whose items are themselves iterable.
    :return: a new list containing the concatenation of all items.
    """
    combined: list[Any] = []
    for sub in it:
        # `extend` consumes any iterable, so generators and other one-shot iterators work as inner items.
        combined.extend(sub)
    return combined
def flatten_with_mapping(it: Iterable[list[Any]]) -> (list[typing.Any], list[int])

Flatten a list of lists, or any kind of nested iterable, into a single list, additionally returning the lengths of each of the items, allowing you to restore the original list of lists using unflatten(). We test that unflatten(*flatten_with_mapping(x)) == x and vice versa.

Expand source code
def flatten_with_mapping(it: Iterable[Iterable[Any]]) -> tuple[list[Any], list[int]]:
    """
    Flatten a list of lists, or any kind of nested iterable, into a single list, additionally returning the lengths
    of each of the items, allowing you to restore the original list of lists using `unflatten`.
    We test that `unflatten(*flatten_with_mapping(x)) == x` and vice versa.

    The input is traversed exactly once. One-shot iterators (e.g. generators) are therefore supported, both as the
    outer iterable and as the inner items.

    :param it: iterable whose items are themselves iterable.
    :return: tuple ``(flattened, mapping)``. ``flattened`` is the concatenation of all items, and ``mapping[i]`` is
        the number of elements contributed by the i-th item.
    """
    flattened: list[Any] = []
    mapping: list[int] = []
    for item in it:
        # Measure how much each item added, rather than calling `len(item)`. A second pass over `it` would find a
        # generator exhausted, and `len` is undefined for inner iterators.
        before = len(flattened)
        flattened.extend(item)
        mapping.append(len(flattened) - before)
    return flattened, mapping
def unflatten(flattened: list[typing.Any], mapping: list[int]) -> list[list[typing.Any]]

Unflatten a list, breaking it into sublists of length indicated by mapping. We test that unflatten(*flatten_with_mapping(x)) == x and vice versa.

Expand source code
def unflatten(flattened: list[Any], mapping: list[int]) -> list[list[Any]]:
    """
    Unflatten a list, breaking it into sublists of length indicated by `mapping`.
    We test that `unflatten(*flatten_with_mapping(x)) == x` and vice versa.

    Each sublist is a slice of `flattened`. If the lengths in `mapping` sum to less than `len(flattened)`, the trailing
    elements are not included in any sublist.

    :param flattened: the flat sequence to split.
    :param mapping: the length of each consecutive sublist.
    :return: one slice of `flattened` per entry in `mapping`.
    """
    pieces: list[list[Any]] = []
    start = 0
    # Running totals of `mapping` give the end offset of each consecutive slice.
    for end in accumulate(mapping, op.add):
        pieces.append(flattened[start:end])
        start = end
    return pieces