muutils.mlutils
miscellaneous utilities for ML pipelines
1"miscellaneous utilities for ML pipelines" 2 3from __future__ import annotations 4 5import json 6import os 7import random 8import typing 9import warnings 10from itertools import islice 11from pathlib import Path 12from typing import Any, Callable, Optional, TypeVar, Union 13 14ARRAY_IMPORTS: bool 15try: 16 import numpy as np 17 import torch 18 19 ARRAY_IMPORTS = True 20except ImportError as e: 21 warnings.warn( 22 f"Numpy or torch not installed. Array operations will not be available.\n{e}" 23 ) 24 ARRAY_IMPORTS = False 25 26DEFAULT_SEED: int = 42 27GLOBAL_SEED: int = DEFAULT_SEED 28 29 30def get_device(device: "Union[str,torch.device,None]" = None) -> "torch.device": 31 """Get the torch.device instance on which `torch.Tensor`s should be allocated.""" 32 if not ARRAY_IMPORTS: 33 raise ImportError( 34 "Numpy or torch not installed. Array operations will not be available." 35 ) 36 try: 37 # if device is given 38 if device is not None: 39 device = torch.device(device) 40 if any( 41 [ 42 torch.cuda.is_available() and device.type == "cuda", 43 torch.backends.mps.is_available() and device.type == "mps", 44 device.type == "cpu", 45 ] 46 ): 47 # if device is given and available 48 pass 49 else: 50 warnings.warn( 51 f"Specified device {device} is not available, falling back to CPU" 52 ) 53 return torch.device("cpu") 54 55 # no device given, infer from availability 56 else: 57 if torch.cuda.is_available(): 58 device = torch.device("cuda") 59 elif torch.backends.mps.is_available(): 60 device = torch.device("mps") 61 else: 62 device = torch.device("cpu") 63 64 # put a dummy tensor on the device to check if it is available 65 _dummy = torch.zeros(1, device=device) 66 67 return device 68 69 except Exception as e: 70 warnings.warn( 71 f"Error while getting device, falling back to CPU. Error: {e}", 72 RuntimeWarning, 73 ) 74 return torch.device("cpu") 75 76 77def set_reproducibility(seed: int = DEFAULT_SEED): 78 """ 79 Improve model reproducibility. See https://github.com/NVIDIA/framework-determinism for more information. 80 81 Deterministic operations tend to have worse performance than nondeterministic operations, so this method trades 82 off performance for reproducibility. Set use_deterministic_algorithms to True to improve performance. 83 """ 84 global GLOBAL_SEED 85 86 GLOBAL_SEED = seed 87 88 random.seed(seed) 89 90 if ARRAY_IMPORTS: 91 np.random.seed(seed) 92 torch.manual_seed(seed) 93 94 torch.use_deterministic_algorithms(True) 95 # Ensure reproducibility for concurrent CUDA streams 96 # see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility. 
97 os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" 98 99 100def chunks(it, chunk_size): 101 """Yield successive chunks from an iterator.""" 102 # https://stackoverflow.com/a/61435714 103 iterator = iter(it) 104 while chunk := list(islice(iterator, chunk_size)): 105 yield chunk 106 107 108def get_checkpoint_paths_for_run( 109 run_path: Path, 110 extension: typing.Literal["pt", "zanj"], 111 checkpoints_format: str = "checkpoints/model.iter_*.{extension}", 112) -> list[tuple[int, Path]]: 113 """get checkpoints of the format from the run_path 114 115 note that `checkpoints_format` should contain a glob pattern with: 116 - unresolved "{extension}" format term for the extension 117 - a wildcard for the iteration number 118 """ 119 120 assert run_path.is_dir(), f"Model path {run_path} is not a directory (expect run directory, not model files)" 121 122 return [ 123 (int(checkpoint_path.stem.split("_")[-1].split(".")[0]), checkpoint_path) 124 for checkpoint_path in sorted( 125 Path(run_path).glob(checkpoints_format.format(extension=extension)) 126 ) 127 ] 128 129 130F = TypeVar("F", bound=Callable[..., Any]) 131 132 133def register_method( 134 method_dict: dict[str, Callable[..., Any]], 135 custom_name: Optional[str] = None, 136) -> Callable[[F], F]: 137 """Decorator to add a method to the method_dict""" 138 139 def decorator(method: F) -> F: 140 method_name: str 141 if custom_name is None: 142 method_name_orig: str | None = getattr(method, "__name__", None) 143 if method_name_orig is None: 144 warnings.warn( 145 f"Method {method} does not have a name, using sanitized repr" 146 ) 147 from muutils.misc import sanitize_identifier 148 149 method_name = sanitize_identifier(repr(method)) 150 else: 151 method_name = method_name_orig 152 else: 153 method_name = custom_name 154 method.__name__ = custom_name 155 assert ( 156 method_name not in method_dict 157 ), f"Method name already exists in method_dict: {method_name = }, {list(method_dict.keys()) = }" 158 method_dict[method_name] = method 159 return method 160 161 return decorator 162 163 164def pprint_summary(summary: dict): 165 print(json.dumps(summary, indent=2))
```python
ARRAY_IMPORTS: bool = True
DEFAULT_SEED: int = 42
GLOBAL_SEED: int = 42
```
```python
def get_device(device: Union[str, torch.device, None] = None) -> torch.device:
```
Get the `torch.device` instance on which `torch.Tensor`s should be allocated.
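A minimal usage sketch (assuming torch is installed; the tensor shape is arbitrary):

```python
import torch

from muutils.mlutils import get_device

# infer the best available device (cuda > mps > cpu)
device = get_device()
x = torch.zeros(8, device=device)

# request a specific device; get_device warns and returns the CPU device
# if the requested device is not available
device = get_device("cuda")
```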
```python
def set_reproducibility(seed: int = 42):
```
Improve model reproducibility. See https://github.com/NVIDIA/framework-determinism for more information.

Deterministic operations tend to have worse performance than nondeterministic operations, so this method trades off performance for reproducibility. Set `use_deterministic_algorithms` to False to improve performance.
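A short usage sketch; the seed value is arbitrary:

```python
from muutils.mlutils import set_reproducibility

set_reproducibility(seed=123)
# python's random, numpy, and torch RNGs are now seeded, and torch is
# switched to deterministic algorithms (when numpy and torch are installed)
```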
```python
def chunks(it, chunk_size):
```
Yield successive chunks from an iterator.
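For example:

```python
from muutils.mlutils import chunks

print(list(chunks(range(7), 3)))
# [[0, 1, 2], [3, 4, 5], [6]]
```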
```python
def get_checkpoint_paths_for_run(
    run_path: Path,
    extension: typing.Literal["pt", "zanj"],
    checkpoints_format: str = "checkpoints/model.iter_*.{extension}",
) -> list[tuple[int, Path]]:
```
Get the checkpoints of the given format from `run_path`, as a list of `(iteration, path)` tuples.
Note that `checkpoints_format` should contain a glob pattern with:
- unresolved "{extension}" format term for the extension
- a wildcard for the iteration number
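A usage sketch; the `runs/my_run` directory and its checkpoint files are hypothetical:

```python
from pathlib import Path

from muutils.mlutils import get_checkpoint_paths_for_run

# expects files like runs/my_run/checkpoints/model.iter_100.pt
for iteration, ckpt_path in get_checkpoint_paths_for_run(Path("runs/my_run"), "pt"):
    print(iteration, ckpt_path)
# note: results are sorted lexicographically by path, not numerically by iteration
```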
```python
def register_method(
    method_dict: dict[str, Callable[..., Any]],
    custom_name: Optional[str] = None,
) -> Callable[[F], F]:
```
Decorator to add a method to the method_dict
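A sketch of typical use; the `LOSS_FUNCTIONS` registry and the loss functions are made up for illustration:

```python
from typing import Any, Callable

from muutils.mlutils import register_method

LOSS_FUNCTIONS: dict[str, Callable[..., Any]] = {}

@register_method(LOSS_FUNCTIONS)
def mse(pred, target):
    # assumes array-like inputs (e.g. torch tensors)
    return ((pred - target) ** 2).mean()

@register_method(LOSS_FUNCTIONS, custom_name="l1")
def mae(pred, target):
    return abs(pred - target).mean()

print(sorted(LOSS_FUNCTIONS))  # ['l1', 'mse']
```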
```python
def pprint_summary(summary: dict):
```

Pretty-print a summary dict as indented JSON.
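For example:

```python
from muutils.mlutils import pprint_summary

pprint_summary({"epoch": 3, "loss": 0.125})
# {
#   "epoch": 3,
#   "loss": 0.125
# }
```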