docs for muutils v0.8.3
View Source on GitHub

muutils.mlutils

miscellaneous utilities for ML pipelines


  1"miscellaneous utilities for ML pipelines"
  2
  3from __future__ import annotations
  4
  5import json
  6import os
  7import random
  8import typing
  9import warnings
 10from itertools import islice
 11from pathlib import Path
 12from typing import Any, Callable, Optional, TypeVar, Union
 13
 14ARRAY_IMPORTS: bool
 15try:
 16    import numpy as np
 17    import torch
 18
 19    ARRAY_IMPORTS = True
 20except ImportError as e:
 21    warnings.warn(
 22        f"Numpy or torch not installed. Array operations will not be available.\n{e}"
 23    )
 24    ARRAY_IMPORTS = False
 25
 26DEFAULT_SEED: int = 42
 27GLOBAL_SEED: int = DEFAULT_SEED
 28
 29
 30def get_device(device: "Union[str,torch.device,None]" = None) -> "torch.device":
 31    """Get the torch.device instance on which `torch.Tensor`s should be allocated."""
 32    if not ARRAY_IMPORTS:
 33        raise ImportError(
 34            "Numpy or torch not installed. Array operations will not be available."
 35        )
 36    try:
 37        # if device is given
 38        if device is not None:
 39            device = torch.device(device)
 40            if any(
 41                [
 42                    torch.cuda.is_available() and device.type == "cuda",
 43                    torch.backends.mps.is_available() and device.type == "mps",
 44                    device.type == "cpu",
 45                ]
 46            ):
 47                # if device is given and available
 48                pass
 49            else:
 50                warnings.warn(
 51                    f"Specified device {device} is not available, falling back to CPU"
 52                )
 53                return torch.device("cpu")
 54
 55        # no device given, infer from availability
 56        else:
 57            if torch.cuda.is_available():
 58                device = torch.device("cuda")
 59            elif torch.backends.mps.is_available():
 60                device = torch.device("mps")
 61            else:
 62                device = torch.device("cpu")
 63
 64        # put a dummy tensor on the device to check if it is available
 65        _dummy = torch.zeros(1, device=device)
 66
 67        return device
 68
 69    except Exception as e:
 70        warnings.warn(
 71            f"Error while getting device, falling back to CPU. Error: {e}",
 72            RuntimeWarning,
 73        )
 74        return torch.device("cpu")
 75
 76
 77def set_reproducibility(seed: int = DEFAULT_SEED):
 78    """
 79    Improve model reproducibility. See https://github.com/NVIDIA/framework-determinism for more information.
 80
    Deterministic operations tend to have worse performance than nondeterministic operations, so this method trades
    off performance for reproducibility. Set use_deterministic_algorithms to False to improve performance.
 83    """
 84    global GLOBAL_SEED
 85
 86    GLOBAL_SEED = seed
 87
 88    random.seed(seed)
 89
 90    if ARRAY_IMPORTS:
 91        np.random.seed(seed)
 92        torch.manual_seed(seed)
 93
 94        torch.use_deterministic_algorithms(True)
 95        # Ensure reproducibility for concurrent CUDA streams
 96        # see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility.
 97        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
 98
 99
100def chunks(it, chunk_size):
101    """Yield successive chunks from an iterator."""
102    # https://stackoverflow.com/a/61435714
103    iterator = iter(it)
104    while chunk := list(islice(iterator, chunk_size)):
105        yield chunk
106
107
108def get_checkpoint_paths_for_run(
109    run_path: Path,
110    extension: typing.Literal["pt", "zanj"],
111    checkpoints_format: str = "checkpoints/model.iter_*.{extension}",
112) -> list[tuple[int, Path]]:
113    """get checkpoints of the format from the run_path
114
115    note that `checkpoints_format` should contain a glob pattern with:
116     - unresolved "{extension}" format term for the extension
117     - a wildcard for the iteration number
118    """
119
120    assert run_path.is_dir(), f"Model path {run_path} is not a directory (expect run directory, not model files)"
121
122    return [
123        (int(checkpoint_path.stem.split("_")[-1].split(".")[0]), checkpoint_path)
124        for checkpoint_path in sorted(
125            Path(run_path).glob(checkpoints_format.format(extension=extension))
126        )
127    ]
128
129
130F = TypeVar("F", bound=Callable[..., Any])
131
132
133def register_method(
134    method_dict: dict[str, Callable[..., Any]],
135    custom_name: Optional[str] = None,
136) -> Callable[[F], F]:
137    """Decorator to add a method to the method_dict"""
138
139    def decorator(method: F) -> F:
140        method_name: str
141        if custom_name is None:
142            method_name_orig: str | None = getattr(method, "__name__", None)
143            if method_name_orig is None:
144                warnings.warn(
145                    f"Method {method} does not have a name, using sanitized repr"
146                )
147                from muutils.misc import sanitize_identifier
148
149                method_name = sanitize_identifier(repr(method))
150            else:
151                method_name = method_name_orig
152        else:
153            method_name = custom_name
154            method.__name__ = custom_name
155        assert (
156            method_name not in method_dict
157        ), f"Method name already exists in method_dict: {method_name = }, {list(method_dict.keys()) = }"
158        method_dict[method_name] = method
159        return method
160
161    return decorator
162
163
164def pprint_summary(summary: dict):
165    print(json.dumps(summary, indent=2))

ARRAY_IMPORTS: bool = True
DEFAULT_SEED: int = 42
GLOBAL_SEED: int = 42
def get_device(device: Union[str, torch.device, None] = None) -> torch.device:

Get the torch.device instance on which torch.Tensors should be allocated.
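
A minimal usage sketch (assuming torch is installed; which branch is taken depends on the hardware available on the host):

import torch
from muutils.mlutils import get_device

# infer the best available device (cuda > mps > cpu)
device = get_device()

# request a specific device; falls back to CPU with a warning if unavailable
device = get_device("cuda")

x = torch.zeros(4, device=device)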

def set_reproducibility(seed: int = 42):

Improve model reproducibility. See https://github.com/NVIDIA/framework-determinism for more information.

Deterministic operations tend to have worse performance than nondeterministic operations, so this method trades off performance for reproducibility. Set use_deterministic_algorithms to False to improve performance.
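
A minimal seeding sketch (accessing GLOBAL_SEED through the module so the updated value is visible after the call):

import muutils.mlutils as mlutils

# seeds python's random, numpy, and torch, and enables deterministic torch ops
mlutils.set_reproducibility(1234)
print(mlutils.GLOBAL_SEED)  # 1234

# with no argument, DEFAULT_SEED (42) is used
mlutils.set_reproducibility()

Note that with deterministic algorithms enabled, torch raises an error for operations that have no deterministic implementation.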

def chunks(it, chunk_size):

Yield successive chunks from an iterator.
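
For example:

from muutils.mlutils import chunks

print(list(chunks(range(7), 3)))
# [[0, 1, 2], [3, 4, 5], [6]]

The final chunk may be shorter than chunk_size, as shown above.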

def get_checkpoint_paths_for_run(
    run_path: pathlib.Path,
    extension: Literal['pt', 'zanj'],
    checkpoints_format: str = 'checkpoints/model.iter_*.{extension}'
) -> list[tuple[int, pathlib.Path]]:

get checkpoints of the format from the run_path

note that checkpoints_format should contain a glob pattern with:

  • unresolved "{extension}" format term for the extension
  • a wildcard for the iteration number
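
A usage sketch, assuming a hypothetical run directory runs/my-run containing files such as checkpoints/model.iter_1000.zanj:

from pathlib import Path
from muutils.mlutils import get_checkpoint_paths_for_run

for iteration, ckpt_path in get_checkpoint_paths_for_run(Path("runs/my-run"), "zanj"):
    print(iteration, ckpt_path)
# e.g. 1000 runs/my-run/checkpoints/model.iter_1000.zanj

The iteration number is parsed from the filename stem, so checkpoints are returned sorted by path with their iteration counts.
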
def register_method(
    method_dict: dict[str, typing.Callable[..., typing.Any]],
    custom_name: Optional[str] = None
) -> Callable[[F], F]:

Decorator to add a method to the method_dict
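
For example, building a registry of named functions (the LOSS_FUNCTIONS dict here is hypothetical):

from typing import Any, Callable
from muutils.mlutils import register_method

LOSS_FUNCTIONS: dict[str, Callable[..., Any]] = {}

@register_method(LOSS_FUNCTIONS)
def l2_loss(pred, target):
    return ((pred - target) ** 2).mean()

@register_method(LOSS_FUNCTIONS, custom_name="l1")
def l1_loss(pred, target):
    return (pred - target).abs().mean()

# LOSS_FUNCTIONS is now {"l2_loss": l2_loss, "l1": l1_loss}

Registering two methods under the same name raises an AssertionError.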

def pprint_summary(summary: dict):

Pretty-print a summary dictionary as indented JSON.
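
For example:

from muutils.mlutils import pprint_summary

pprint_summary({"run": "my-run", "final_loss": 0.12})
# {
#   "run": "my-run",
#   "final_loss": 0.12
# }

The summary dict must be JSON-serializable.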