muutils.tensor_info
import numpy as np
from typing import Union, Any, Literal, List, Dict, overload

# Global color definitions
COLORS: Dict[str, Dict[str, str]] = {
    "latex": {
        "range": r"\textcolor{purple}",
        "mean": r"\textcolor{teal}",
        "std": r"\textcolor{orange}",
        "median": r"\textcolor{green}",
        "warning": r"\textcolor{red}",
        "shape": r"\textcolor{magenta}",
        "dtype": r"\textcolor{gray}",
        "device": r"\textcolor{gray}",
        "requires_grad": r"\textcolor{gray}",
        "sparkline": r"\textcolor{blue}",
        "reset": "",
    },
    "terminal": {
        "range": "\033[35m",  # purple
        "mean": "\033[36m",  # cyan/teal
        "std": "\033[33m",  # yellow/orange
        "median": "\033[32m",  # green
        "warning": "\033[31m",  # red
        "shape": "\033[95m",  # bright magenta
        "dtype": "\033[90m",  # gray
        "device": "\033[90m",  # gray
        "requires_grad": "\033[90m",  # gray
        "sparkline": "\033[34m",  # blue
        "reset": "\033[0m",
    },
    "none": {
        "range": "",
        "mean": "",
        "std": "",
        "median": "",
        "warning": "",
        "shape": "",
        "dtype": "",
        "device": "",
        "requires_grad": "",
        "sparkline": "",
        "reset": "",
    },
}

OutputFormat = Literal["unicode", "latex", "ascii"]

SYMBOLS: Dict[OutputFormat, Dict[str, str]] = {
    "latex": {
        "range": r"\mathcal{R}",
        "mean": r"\mu",
        "std": r"\sigma",
        "median": r"\tilde{x}",
        "distribution": r"\mathbb{P}",
        "nan_values": r"\text{NANvals}",
        "warning": "!!!",
        "requires_grad": r"\nabla",
        "true": r"\checkmark",
        "false": r"\times",
    },
    "unicode": {
        "range": "R",
        "mean": "μ",
        "std": "σ",
        "median": "x̃",
        "distribution": "ℙ",
        "nan_values": "NANvals",
        "warning": "🚨",
        "requires_grad": "∇",
        "true": "✓",
        "false": "✗",
    },
    "ascii": {
        "range": "range",
        "mean": "mean",
        "std": "std",
        "median": "med",
        "distribution": "dist",
        "nan_values": "NANvals",
        "warning": "!!!",
        "requires_grad": "requires_grad",
        "true": "1",
        "false": "0",
    },
}
"Symbols for different formats"

SPARK_CHARS: Dict[OutputFormat, List[str]] = {
    "unicode": list(" ▁▂▃▄▅▆▇█"),
    "ascii": list(" _.-~=#"),
    "latex": list(" ▁▂▃▄▅▆▇█"),
}
"characters for sparklines in different formats"


def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    # Parameters:
    - `A : array-like`
        Array to analyze (numpy array or torch tensor)

    # Returns:
    - `Dict[str, Any]`
        Dictionary containing raw statistical information with numeric values
    """
    result: Dict[str, Any] = {
        "is_tensor": None,
        "device": None,
        "requires_grad": None,
        "shape": None,
        "dtype": None,
        "size": None,
        "has_nans": None,
        "nan_count": None,
        "nan_percent": None,
        "min": None,
        "max": None,
        "range": None,
        "mean": None,
        "std": None,
        "median": None,
        "histogram": None,
        "bins": None,
        "status": None,
    }

    # Check if it's a tensor by looking at its class name
    # This avoids importing torch directly
    A_type: str = type(A).__name__
    result["is_tensor"] = A_type == "Tensor"

    # Try to get device information if it's a tensor
    if result["is_tensor"]:
        try:
            result["device"] = str(getattr(A, "device", None))
        except:  # noqa: E722
            pass

    # Convert to numpy array for calculations
    try:
        # For PyTorch tensors
        if result["is_tensor"]:
            # Check if tensor is on GPU
            is_cuda: bool = False
            try:
                is_cuda = bool(getattr(A, "is_cuda", False))
            except:  # noqa: E722
                pass

            if is_cuda:
                try:
                    # Try to get CPU tensor first
                    cpu_tensor = getattr(A, "cpu", lambda: A)()
                except:  # noqa: E722
                    A_np = np.array([])
            else:
                cpu_tensor = A
            try:
                # For CPU tensor, just detach and convert
                detached = getattr(cpu_tensor, "detach", lambda: cpu_tensor)()
                A_np = getattr(detached, "numpy", lambda: np.array([]))()
            except:  # noqa: E722
                A_np = np.array([])
        else:
            # For numpy arrays and other array-like objects
            A_np = np.asarray(A)
    except:  # noqa: E722
        A_np = np.array([])

    # Get basic information
    try:
        result["shape"] = A_np.shape
        result["dtype"] = str(A.dtype if result["is_tensor"] else A_np.dtype)
        result["size"] = A_np.size
        result["requires_grad"] = getattr(A, "requires_grad", None)
    except:  # noqa: E722
        pass

    # If array is empty, return early
    if result["size"] == 0:
        result["status"] = "empty array"
        return result

    # Flatten array for statistics if it's multi-dimensional
    try:
        if len(A_np.shape) > 1:
            A_flat = A_np.flatten()
        else:
            A_flat = A_np
    except:  # noqa: E722
        A_flat = A_np

    # Check for NaN values
    try:
        nan_mask = np.isnan(A_flat)
        result["nan_count"] = np.sum(nan_mask)
        result["has_nans"] = result["nan_count"] > 0
        if result["size"] > 0:
            result["nan_percent"] = (result["nan_count"] / result["size"]) * 100
    except:  # noqa: E722
        pass

    # If all values are NaN, return early
    if result["has_nans"] and result["nan_count"] == result["size"]:
        result["status"] = "all NaN"
        return result

    # Calculate statistics
    try:
        if result["has_nans"]:
            result["min"] = float(np.nanmin(A_flat))
            result["max"] = float(np.nanmax(A_flat))
            result["mean"] = float(np.nanmean(A_flat))
            result["std"] = float(np.nanstd(A_flat))
            result["median"] = float(np.nanmedian(A_flat))
            result["range"] = (result["min"], result["max"])

            # Remove NaNs for histogram
            A_hist = A_flat[~nan_mask]
        else:
            result["min"] = float(np.min(A_flat))
            result["max"] = float(np.max(A_flat))
            result["mean"] = float(np.mean(A_flat))
            result["std"] = float(np.std(A_flat))
            result["median"] = float(np.median(A_flat))
            result["range"] = (result["min"], result["max"])

            A_hist = A_flat

        # Calculate histogram data for sparklines
        if A_hist.size > 0:
            try:
                hist, bins = np.histogram(A_hist, bins=hist_bins)
                result["histogram"] = hist
                result["bins"] = bins
            except:  # noqa: E722
                pass

        result["status"] = "ok"
    except Exception as e:
        result["status"] = f"error: {str(e)}"

    return result


def generate_sparkline(
    histogram: np.ndarray,
    format: Literal["unicode", "latex", "ascii"] = "unicode",
    log_y: bool = False,
) -> str:
    """Generate a sparkline visualization of the histogram.

    # Parameters:
    - `histogram : np.ndarray`
        Histogram data
    - `format : Literal["unicode", "latex", "ascii"]`
        Output format (defaults to `"unicode"`)
    - `log_y : bool`
        Whether to use logarithmic y-scale (defaults to `False`)

    # Returns:
    - `str`
        Sparkline visualization
    """
    if histogram is None or len(histogram) == 0:
        return ""

    # Get the appropriate character set
    if format in SPARK_CHARS:
        chars = SPARK_CHARS[format]
    else:
        chars = SPARK_CHARS["ascii"]

    # Handle log scale
    if log_y:
        # Add small value to avoid log(0)
        hist_data = np.log1p(histogram)
    else:
        hist_data = histogram

    # Normalize to character set range
    if hist_data.max() > 0:
        normalized = hist_data / hist_data.max() * (len(chars) - 1)
    else:
        normalized = np.zeros_like(hist_data)

    # Convert to characters
    spark = ""
    for val in normalized:
        idx = int(val)
        spark += chars[idx]

    return spark


DEFAULT_SETTINGS: Dict[str, Any] = dict(
    fmt="unicode",
    precision=2,
    stats=True,
    shape=True,
    dtype=True,
    device=True,
    requires_grad=True,
    sparkline=False,
    sparkline_bins=5,
    sparkline_logy=False,
    colored=False,
    as_list=False,
    eq_char="=",
)


class _UseDefaultType:
    pass


_USE_DEFAULT = _UseDefaultType()


@overload
def array_summary(
    as_list: Literal[True],
    **kwargs,
) -> List[str]: ...
@overload
def array_summary(
    as_list: Literal[False],
    **kwargs,
) -> str: ...
def array_summary(  # type: ignore[misc]
    array,
    fmt: OutputFormat = _USE_DEFAULT,  # type: ignore[assignment]
    precision: int = _USE_DEFAULT,  # type: ignore[assignment]
    stats: bool = _USE_DEFAULT,  # type: ignore[assignment]
    shape: bool = _USE_DEFAULT,  # type: ignore[assignment]
    dtype: bool = _USE_DEFAULT,  # type: ignore[assignment]
    device: bool = _USE_DEFAULT,  # type: ignore[assignment]
    requires_grad: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_bins: int = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_logy: bool = _USE_DEFAULT,  # type: ignore[assignment]
    colored: bool = _USE_DEFAULT,  # type: ignore[assignment]
    eq_char: str = _USE_DEFAULT,  # type: ignore[assignment]
    as_list: bool = _USE_DEFAULT,  # type: ignore[assignment]
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    # Parameters:
    - `array`
        array-like object (numpy array or torch tensor)
    - `precision : int`
        Decimal places (defaults to `2`)
    - `format : Literal["unicode", "latex", "ascii"]`
        Output format (defaults to `{default_fmt}`)
    - `stats : bool`
        Whether to include statistical info (μ, σ, x̃) (defaults to `True`)
    - `shape : bool`
        Whether to include shape info (defaults to `True`)
    - `dtype : bool`
        Whether to include dtype info (defaults to `True`)
    - `device : bool`
        Whether to include device info for torch tensors (defaults to `True`)
    - `requires_grad : bool`
        Whether to include requires_grad info for torch tensors (defaults to `True`)
    - `sparkline : bool`
        Whether to include a sparkline visualization (defaults to `False`)
    - `sparkline_width : int`
        Width of the sparkline (defaults to `20`)
    - `sparkline_logy : bool`
        Whether to use logarithmic y-scale for sparkline (defaults to `False`)
    - `colored : bool`
        Whether to add color to output (defaults to `False`)
    - `as_list : bool`
        Whether to return as list of strings instead of joined string (defaults to `False`)

    # Returns:
    - `Union[str, List[str]]`
        Formatted statistical summary, either as string or list of strings
    """
    if fmt is _USE_DEFAULT:
        fmt = DEFAULT_SETTINGS["fmt"]
    if precision is _USE_DEFAULT:
        precision = DEFAULT_SETTINGS["precision"]
    if stats is _USE_DEFAULT:
        stats = DEFAULT_SETTINGS["stats"]
    if shape is _USE_DEFAULT:
        shape = DEFAULT_SETTINGS["shape"]
    if dtype is _USE_DEFAULT:
        dtype = DEFAULT_SETTINGS["dtype"]
    if device is _USE_DEFAULT:
        device = DEFAULT_SETTINGS["device"]
    if requires_grad is _USE_DEFAULT:
        requires_grad = DEFAULT_SETTINGS["requires_grad"]
    if sparkline is _USE_DEFAULT:
        sparkline = DEFAULT_SETTINGS["sparkline"]
    if sparkline_bins is _USE_DEFAULT:
        sparkline_bins = DEFAULT_SETTINGS["sparkline_bins"]
    if sparkline_logy is _USE_DEFAULT:
        sparkline_logy = DEFAULT_SETTINGS["sparkline_logy"]
    if colored is _USE_DEFAULT:
        colored = DEFAULT_SETTINGS["colored"]
    if as_list is _USE_DEFAULT:
        as_list = DEFAULT_SETTINGS["as_list"]
    if eq_char is _USE_DEFAULT:
        eq_char = DEFAULT_SETTINGS["eq_char"]

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    # Get symbols for the current format
    symbols: Dict[str, str] = SYMBOLS[fmt]

    # Helper function to colorize text
    def colorize(text: str, color_key: str) -> str:
        if using_tex:
            return f"{colors[color_key]}{{{text}}}" if colors[color_key] else text
        else:
            return (
                f"{colors[color_key]}{text}{colors['reset']}"
                if colors[color_key]
                else text
            )

    # Format string for numbers
    float_fmt: str = f".{precision}f"

    # Handle error status or empty array
    if (
        array_data["status"] in ["empty array", "all NaN", "unknown"]
        or array_data["size"] == 0
    ):
        status = array_data["status"]
        result_parts.append(colorize(symbols["warning"] + " " + status, "warning"))
    else:
        # Add NaN warning at the beginning if there are NaNs
        if array_data["has_nans"]:
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        # Statistics
        if stats:
            for stat_key in ["mean", "std", "median"]:
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    result_parts.append(f"{symbols[stat_key]}={stat_colored}")

        # Range (min, max)
        if array_data["range"] is not None:
            min_val, max_val = array_data["range"]
            min_str: str = f"{min_val:{float_fmt}}"
            max_str: str = f"{max_val:{float_fmt}}"
            min_colored: str = colorize(min_str, "range")
            max_colored: str = colorize(max_str, "range")
            range_str: str = f"{symbols['range']}=[{min_colored},{max_colored}]"
            result_parts.append(range_str)

    # Add sparkline if requested
    if sparkline and array_data["histogram"] is not None:
        spark = generate_sparkline(
            array_data["histogram"], format=fmt, log_y=sparkline_logy
        )
        if spark:
            spark_colored = colorize(spark, "sparkline")
            result_parts.append(f"{symbols['distribution']}{eq_char}|{spark_colored}|")

    # Add shape if requested
    if shape and array_data["shape"]:
        shape_val = array_data["shape"]
        if len(shape_val) == 1:
            shape_str = str(shape_val[0])
        else:
            shape_str = (
                "(" + ",".join(colorize(str(dim), "shape") for dim in shape_val) + ")"
            )
        result_parts.append(f"shape{eq_char}{shape_str}")

    # Add dtype if requested
    if dtype and array_data["dtype"]:
        result_parts.append(colorize(f"dtype={array_data['dtype']}", "dtype"))

    # Add device if requested and it's a tensor with device info
    if device and array_data["is_tensor"] and array_data["device"]:
        result_parts.append(
            colorize(f"device{eq_char}{array_data['device']}", "device")
        )

    # Add gradient info
    if requires_grad and array_data["is_tensor"]:
        bool_req_grad_symb: str = (
            symbols["true"] if array_data["requires_grad"] else symbols["false"]
        )
        result_parts.append(
            colorize(symbols["requires_grad"] + bool_req_grad_symb, "requires_grad")
        )

    # Return as list if requested, otherwise join with spaces
    if as_list:
        return result_parts
    else:
        joinchar: str = r" \quad " if using_tex else " "
        return joinchar.join(result_parts)
COLORS: Dict[str, Dict[str, str]] =
{'latex': {'range': '\\textcolor{purple}', 'mean': '\\textcolor{teal}', 'std': '\\textcolor{orange}', 'median': '\\textcolor{green}', 'warning': '\\textcolor{red}', 'shape': '\\textcolor{magenta}', 'dtype': '\\textcolor{gray}', 'device': '\\textcolor{gray}', 'requires_grad': '\\textcolor{gray}', 'sparkline': '\\textcolor{blue}', 'reset': ''}, 'terminal': {'range': '\x1b[35m', 'mean': '\x1b[36m', 'std': '\x1b[33m', 'median': '\x1b[32m', 'warning': '\x1b[31m', 'shape': '\x1b[95m', 'dtype': '\x1b[90m', 'device': '\x1b[90m', 'requires_grad': '\x1b[90m', 'sparkline': '\x1b[34m', 'reset': '\x1b[0m'}, 'none': {'range': '', 'mean': '', 'std': '', 'median': '', 'warning': '', 'shape': '', 'dtype': '', 'device': '', 'requires_grad': '', 'sparkline': '', 'reset': ''}}
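For illustration, the "terminal" entries are plain ANSI escape sequences, and the colorize helper inside array_summary wraps each value between a color code and the "reset" code. A minimal sketch (not part of the module) of that pattern:

import muutils.tensor_info as ti

colors = ti.COLORS["terminal"]
value = "0.42"
# wrap the value in the "mean" color and reset afterwards, as colorize() does
print(f"{colors['mean']}{value}{colors['reset']}")
# the "none" scheme leaves text unchanged, which is how colored=False behaves
print(f"{ti.COLORS['none']['mean']}{value}{ti.COLORS['none']['reset']}")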
OutputFormat =
typing.Literal['unicode', 'latex', 'ascii']
SYMBOLS: Dict[Literal['unicode', 'latex', 'ascii'], Dict[str, str]] =
{'latex': {'range': '\\mathcal{R}', 'mean': '\\mu', 'std': '\\sigma', 'median': '\\tilde{x}', 'distribution': '\\mathbb{P}', 'nan_values': '\\text{NANvals}', 'warning': '!!!', 'requires_grad': '\\nabla', 'true': '\\checkmark', 'false': '\\times'}, 'unicode': {'range': 'R', 'mean': 'μ', 'std': 'σ', 'median': 'x̃', 'distribution': 'ℙ', 'nan_values': 'NANvals', 'warning': '🚨', 'requires_grad': '∇', 'true': '✓', 'false': '✗'}, 'ascii': {'range': 'range', 'mean': 'mean', 'std': 'std', 'median': 'med', 'distribution': 'dist', 'nan_values': 'NANvals', 'warning': '!!!', 'requires_grad': 'requires_grad', 'true': '1', 'false': '0'}}
Symbols for different formats
SPARK_CHARS: Dict[Literal['unicode', 'latex', 'ascii'], List[str]] =
{'unicode': [' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'], 'ascii': [' ', '_', '.', '-', '~', '=', '#'], 'latex': [' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█']}
characters for sparklines in different formats
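SYMBOLS and SPARK_CHARS are keyed by the same OutputFormat values, so switching a summary between unicode, latex, and ascii output only changes which key is looked up. A small sketch:

from muutils.tensor_info import SYMBOLS, SPARK_CHARS

for fmt in ("unicode", "ascii"):
    # the same logical symbol resolves to a format-appropriate string
    print(fmt, SYMBOLS[fmt]["mean"], SYMBOLS[fmt]["std"])

# sparkline character ramps, from empty to full
print("".join(SPARK_CHARS["unicode"]))  # " ▁▂▃▄▅▆▇█"
print("".join(SPARK_CHARS["ascii"]))    # " _.-~=#"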
def array_info(A: Any, hist_bins: int = 5) -> Dict[str, Any]:
Extract statistical information from an array-like object.

# Parameters:
- `A : array-like`
    Array to analyze (numpy array or torch tensor)
- `hist_bins : int`
    Number of histogram bins to compute for sparklines (defaults to `5`)

# Returns:
- `Dict[str, Any]`
    Dictionary containing raw statistical information with numeric values
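A short usage sketch (the exact numbers depend on the input; NaNs are handled with the nan-aware numpy reductions shown in the source):

import numpy as np
from muutils.tensor_info import array_info

info = array_info(np.array([1.0, 2.0, np.nan, 4.0]), hist_bins=3)
print(info["shape"], info["dtype"], info["size"])   # (4,) float64 4
print(info["has_nans"], info["nan_count"])          # True 1
print(info["mean"], info["range"], info["status"])  # nan-aware mean, (1.0, 4.0), 'ok'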
def generate_sparkline(histogram: numpy.ndarray, format: Literal['unicode', 'latex', 'ascii'] = 'unicode', log_y: bool = False) -> str:
Generate a sparkline visualization of the histogram.

# Parameters:
- `histogram : np.ndarray`
    Histogram data
- `format : Literal["unicode", "latex", "ascii"]`
    Output format (defaults to `"unicode"`)
- `log_y : bool`
    Whether to use logarithmic y-scale (defaults to `False`)

# Returns:
- `str`
    Sparkline visualization
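A usage sketch; bar heights scale relative to the largest bin, and log_y compresses the difference between small and large bins:

import numpy as np
from muutils.tensor_info import generate_sparkline

hist = np.array([1, 5, 9, 3, 0])
print(generate_sparkline(hist))                  # unicode bars, tallest for the bin of 9
print(generate_sparkline(hist, format="ascii"))  # same shape drawn with " _.-~=#"
print(generate_sparkline(hist, log_y=True))      # log1p scaling before normalization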
DEFAULT_SETTINGS: Dict[str, Any] =
{'fmt': 'unicode', 'precision': 2, 'stats': True, 'shape': True, 'dtype': True, 'device': True, 'requires_grad': True, 'sparkline': False, 'sparkline_bins': 5, 'sparkline_logy': False, 'colored': False, 'as_list': False, 'eq_char': '='}
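Every array_summary argument left at its sentinel default falls back to the matching entry in DEFAULT_SETTINGS at call time, so mutating this dict changes the module-wide defaults. A sketch of that usage (a convention inferred from the source, not a separately documented API):

import numpy as np
import muutils.tensor_info as ti

ti.DEFAULT_SETTINGS["sparkline"] = True   # include sparklines by default
ti.DEFAULT_SETTINGS["fmt"] = "ascii"      # use plain-ascii symbols by default

# subsequent calls pick up the new defaults unless overridden per call
print(ti.array_summary(np.arange(10)))
print(ti.array_summary(np.arange(10), fmt="unicode"))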
def array_summary(
    array,
    fmt: Literal['unicode', 'latex', 'ascii'] = _USE_DEFAULT,
    precision: int = _USE_DEFAULT,
    stats: bool = _USE_DEFAULT,
    shape: bool = _USE_DEFAULT,
    dtype: bool = _USE_DEFAULT,
    device: bool = _USE_DEFAULT,
    requires_grad: bool = _USE_DEFAULT,
    sparkline: bool = _USE_DEFAULT,
    sparkline_bins: int = _USE_DEFAULT,
    sparkline_logy: bool = _USE_DEFAULT,
    colored: bool = _USE_DEFAULT,
    eq_char: str = _USE_DEFAULT,
    as_list: bool = _USE_DEFAULT,
) -> Union[str, List[str]]:
Format array information into a readable summary.

# Parameters:
- `array`
    array-like object (numpy array or torch tensor)
- `fmt : Literal["unicode", "latex", "ascii"]`
    Output format (defaults to `"unicode"`)
- `precision : int`
    Decimal places (defaults to `2`)
- `stats : bool`
    Whether to include statistical info (μ, σ, x̃) (defaults to `True`)
- `shape : bool`
    Whether to include shape info (defaults to `True`)
- `dtype : bool`
    Whether to include dtype info (defaults to `True`)
- `device : bool`
    Whether to include device info for torch tensors (defaults to `True`)
- `requires_grad : bool`
    Whether to include requires_grad info for torch tensors (defaults to `True`)
- `sparkline : bool`
    Whether to include a sparkline visualization (defaults to `False`)
- `sparkline_bins : int`
    Number of histogram bins for the sparkline (defaults to `5`)
- `sparkline_logy : bool`
    Whether to use logarithmic y-scale for sparkline (defaults to `False`)
- `colored : bool`
    Whether to add color to output (defaults to `False`)
- `eq_char : str`
    Character placed between a label and its value (defaults to `"="`)
- `as_list : bool`
    Whether to return as list of strings instead of joined string (defaults to `False`)

# Returns:
- `Union[str, List[str]]`
    Formatted statistical summary, either as string or list of strings
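A usage sketch (output shown schematically; the exact statistics depend on the data):

import numpy as np
from muutils.tensor_info import array_summary

x = np.random.default_rng(0).normal(size=(64, 64))
print(array_summary(x))
# -> "μ=<mean> σ=<std> x̃=<median> R=[<min>,<max>] shape=(64,64) dtype=float64"

parts = array_summary(x, fmt="ascii", sparkline=True, as_list=True)
print(parts)  # list of the individual segments, using ascii symbols and a sparkline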