docs for muutils v0.8.8
View Source on GitHub

muutils.math.matrix_powers


  1from __future__ import annotations
  2
  3from typing import List, Sequence, TYPE_CHECKING
  4
  5import numpy as np
  6from jaxtyping import Float, Int
  7
  8if TYPE_CHECKING:
  9    pass
 10
 11
 12def matrix_powers(
 13    A: Float[np.ndarray, "n n"],
 14    powers: Sequence[int],
 15) -> Float[np.ndarray, "n_powers n n"]:
 16    """Compute multiple powers of a matrix efficiently.
 17
 18    Uses binary exponentiation to compute powers in O(log max(powers))
 19    matrix multiplications, avoiding redundant calculations when
 20    computing multiple powers.
 21
 22    # Parameters:
 23     - `A : Float[np.ndarray, "n n"]`
 24            Square matrix to exponentiate
 25     - `powers : Sequence[int]`
 26            List of powers to compute (non-negative integers)
 27
 28    # Returns:
 29     - `dict[int, Float[np.ndarray, "n n"]]`
 30            Dictionary mapping each requested power to the corresponding matrix power
 31    """
 32    dim_n: int = A.shape[0]
 33    assert A.shape[0] == A.shape[1], f"Matrix must be square, but got {A.shape = }"
 34    powers_np: Int[np.ndarray, "n_powers_unique"] = np.array(
 35        sorted(set(powers)), dtype=int
 36    )
 37    n_powers_unique: int = len(powers_np)
 38
 39    if n_powers_unique < 1:
 40        raise ValueError(f"No powers requested: {powers = }")
 41
 42    output: Float[np.ndarray, "n_powers_unique n n"] = np.full(
 43        (n_powers_unique, dim_n, dim_n),
 44        fill_value=np.nan,
 45        dtype=A.dtype,
 46    )
 47
 48    # Find the maximum power to compute
 49    max_power: int = max(powers_np)
 50
 51    # Precompute all powers of 2 up to the largest power needed
 52    # This forms our basis for binary decomposition
 53    powers_of_two: dict[int, Float[np.ndarray, "n n"]] = {}
 54    powers_of_two[0] = np.eye(dim_n, dtype=A.dtype)
 55    powers_of_two[1] = A.copy()
 56
 57    # Compute powers of 2: A^2, A^4, A^8, ...
 58    p: int = 1
 59    while p < max_power:
 60        if p <= max_power:
 61            A_power_p = powers_of_two[p]
 62            powers_of_two[p * 2] = A_power_p @ A_power_p
 63        p = p * 2
 64
 65    # For each requested power, compute it using the powers of 2
 66    for p_idx, power in enumerate(powers_np):
 67        # Decompose power into sum of powers of 2
 68        temp_result: Float[np.ndarray, "n n"] = powers_of_two[0].copy()
 69        temp_power: int = power
 70        p_temp: int = 1
 71
 72        while temp_power > 0:
 73            if temp_power % 2 == 1:
 74                temp_result = temp_result @ powers_of_two[p_temp]
 75            temp_power = temp_power // 2
 76            p_temp *= 2
 77
 78        output[p_idx] = temp_result
 79
 80    return output
 81
 82
 83# BUG: breaks with integer matrices???
 84# TYPING: jaxtyping hints not working here, separate file for torch implementation?
 85def matrix_powers_torch(
 86    A,  # : Float["torch.Tensor", "n n"],
 87    powers: Sequence[int],
 88):  # Float["torch.Tensor", "n_powers n n"]:
 89    """Compute multiple powers of a matrix efficiently.
 90
 91    Uses binary exponentiation to compute powers in O(log max(powers))
 92    matrix multiplications, avoiding redundant calculations when
 93    computing multiple powers.
 94
 95    # Parameters:
 96     - `A : Float[torch.Tensor, "n n"]`
 97        Square matrix to exponentiate
 98     - `powers : Sequence[int]`
 99        List of powers to compute (non-negative integers)
100
101    # Returns:
102     - `Float[torch.Tensor, "n_powers n n"]`
103        Tensor containing the requested matrix powers stacked along the first dimension
104
105    # Raises:
106     - `ValueError` : If no powers are requested or if A is not a square matrix
107    """
108
109    import torch
110
111    if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
112        raise ValueError(f"Matrix must be square, but got {A.shape = }")
113
114    dim_n: int = A.shape[0]
115    # Get unique powers and sort them
116    unique_powers: List[int] = sorted(set(powers))
117    n_powers_unique: int = len(unique_powers)
118    powers_tensor: Int[torch.Tensor, "n_powers_unique"] = torch.tensor(
119        unique_powers, dtype=torch.int64, device=A.device
120    )
121
122    if n_powers_unique < 1:
123        raise ValueError(f"No powers requested: {powers = }")
124
125    output: Float[torch.Tensor, "n_powers_unique n n"] = torch.full(
126        (n_powers_unique, dim_n, dim_n),
127        float("nan"),
128        dtype=A.dtype,
129        device=A.device,
130    )
131
132    # Find the maximum power to compute
133    max_power: int = int(powers_tensor.max().item())
134
135    # Precompute all powers of 2 up to the largest power needed
136    # This forms our basis for binary decomposition
137    powers_of_two: dict[int, Float[torch.Tensor, "n n"]] = {}
138    powers_of_two[0] = torch.eye(dim_n, dtype=A.dtype, device=A.device)
139    powers_of_two[1] = A.clone()
140
141    # Compute powers of 2: A^2, A^4, A^8, ...
142    p: int = 1
143    while p < max_power:
144        if p <= max_power:
145            A_power_p: Float[torch.Tensor, "n n"] = powers_of_two[p]
146            powers_of_two[p * 2] = A_power_p @ A_power_p
147        p = p * 2
148
149    # For each requested power, compute it using the powers of 2
150    for p_idx, power in enumerate(unique_powers):
151        # Decompose power into sum of powers of 2
152        temp_result: Float[torch.Tensor, "n n"] = powers_of_two[0].clone()
153        temp_power: int = power
154        p_temp: int = 1
155
156        while temp_power > 0:
157            if temp_power % 2 == 1:
158                temp_result = temp_result @ powers_of_two[p_temp]
159            temp_power = temp_power // 2
160            p_temp *= 2
161
162        output[p_idx] = temp_result
163
164    return output

def matrix_powers( A: jaxtyping.Float[ndarray, 'n n'], powers: Sequence[int]) -> jaxtyping.Float[ndarray, 'n_powers n n']:
def matrix_powers(
    A: Float[np.ndarray, "n n"],
    powers: Sequence[int],
) -> Float[np.ndarray, "n_powers n n"]:
    """Compute multiple powers of a matrix efficiently.

    Uses binary exponentiation: the repeated squares `A, A^2, A^4, ...` are
    computed once (O(log max(powers)) matrix multiplications) and shared by
    all requested powers, each of which then needs at most O(log power)
    additional multiplications.

    # Parameters:
     - `A : Float[np.ndarray, "n n"]`
            Square matrix to exponentiate. The output keeps `A.dtype`, so
            integer matrices produce exact integer powers.
     - `powers : Sequence[int]`
            Powers to compute (non-negative integers). Duplicates are
            dropped and the powers are processed in ascending order.

    # Returns:
     - `Float[np.ndarray, "n_powers_unique n n"]`
            `output[i]` is `A` raised to the `i`-th smallest unique requested
            power; `A^0` is the identity matrix.

    # Raises:
     - `ValueError` : if `A` is not a square 2D matrix, if no powers are
        requested, or if any requested power is negative
    """
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"Matrix must be square, but got {A.shape = }")
    dim_n: int = A.shape[0]

    unique_powers: List[int] = sorted({int(p) for p in powers})
    if not unique_powers:
        raise ValueError(f"No powers requested: {powers = }")
    if unique_powers[0] < 0:
        # the bit decomposition below only handles non-negative exponents;
        # a negative power would otherwise silently produce the identity
        raise ValueError(f"Powers must be non-negative, but got {powers = }")
    max_power: int = unique_powers[-1]

    # A^(2^k) for every 2^k <= max_power -- exactly the terms any requested
    # power can use in its binary decomposition, so no squaring is wasted
    powers_of_two: dict[int, Float[np.ndarray, "n n"]] = {1: A}
    p: int = 1
    while p * 2 <= max_power:
        powers_of_two[p * 2] = powers_of_two[p] @ powers_of_two[p]
        p *= 2

    # every slot is assigned in the loop below, so no sentinel fill is needed
    # (a NaN fill is not representable in integer dtypes)
    output: Float[np.ndarray, "n_powers_unique n n"] = np.empty(
        (len(unique_powers), dim_n, dim_n), dtype=A.dtype
    )

    for idx, power in enumerate(unique_powers):
        # multiply together A^(2^k) for each set bit k of `power`; starting
        # from None instead of the identity avoids a redundant multiplication
        result = None
        bit: int = 1
        remaining: int = power
        while remaining > 0:
            if remaining & 1:
                factor = powers_of_two[bit]
                result = factor if result is None else result @ factor
            remaining >>= 1
            bit <<= 1
        # assignment copies into `output`, so `A` itself is never aliased
        output[idx] = np.eye(dim_n, dtype=A.dtype) if result is None else result

    return output

Compute multiple powers of a matrix efficiently.

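Uses binary exponentiation: the repeated squares A, A^2, A^4, … are computed once, using O(log max(powers)) matrix multiplications. They are shared by all requested powers, each of which then needs at most O(log power) additional multiplications, which avoids redundant work when computing multiple powers.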

Parameters:

  • A : Float[np.ndarray, "n n"] Square matrix to exponentiate
  • powers : Sequence[int] List of powers to compute (non-negative integers)

Returns:

  • Float[np.ndarray, "n_powers_unique n n"] Array of the requested matrix powers, with duplicate powers removed and sorted by ascending power, stacked along the first dimension
def matrix_powers_torch(A, powers: Sequence[int]):
 86def matrix_powers_torch(
 87    A,  # : Float["torch.Tensor", "n n"],
 88    powers: Sequence[int],
 89):  # Float["torch.Tensor", "n_powers n n"]:
 90    """Compute multiple powers of a matrix efficiently.
 91
 92    Uses binary exponentiation to compute powers in O(log max(powers))
 93    matrix multiplications, avoiding redundant calculations when
 94    computing multiple powers.
 95
 96    # Parameters:
 97     - `A : Float[torch.Tensor, "n n"]`
 98        Square matrix to exponentiate
 99     - `powers : Sequence[int]`
100        List of powers to compute (non-negative integers)
101
102    # Returns:
103     - `Float[torch.Tensor, "n_powers n n"]`
104        Tensor containing the requested matrix powers stacked along the first dimension
105
106    # Raises:
107     - `ValueError` : If no powers are requested or if A is not a square matrix
108    """
109
110    import torch
111
112    if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
113        raise ValueError(f"Matrix must be square, but got {A.shape = }")
114
115    dim_n: int = A.shape[0]
116    # Get unique powers and sort them
117    unique_powers: List[int] = sorted(set(powers))
118    n_powers_unique: int = len(unique_powers)
119    powers_tensor: Int[torch.Tensor, "n_powers_unique"] = torch.tensor(
120        unique_powers, dtype=torch.int64, device=A.device
121    )
122
123    if n_powers_unique < 1:
124        raise ValueError(f"No powers requested: {powers = }")
125
126    output: Float[torch.Tensor, "n_powers_unique n n"] = torch.full(
127        (n_powers_unique, dim_n, dim_n),
128        float("nan"),
129        dtype=A.dtype,
130        device=A.device,
131    )
132
133    # Find the maximum power to compute
134    max_power: int = int(powers_tensor.max().item())
135
136    # Precompute all powers of 2 up to the largest power needed
137    # This forms our basis for binary decomposition
138    powers_of_two: dict[int, Float[torch.Tensor, "n n"]] = {}
139    powers_of_two[0] = torch.eye(dim_n, dtype=A.dtype, device=A.device)
140    powers_of_two[1] = A.clone()
141
142    # Compute powers of 2: A^2, A^4, A^8, ...
143    p: int = 1
144    while p < max_power:
145        if p <= max_power:
146            A_power_p: Float[torch.Tensor, "n n"] = powers_of_two[p]
147            powers_of_two[p * 2] = A_power_p @ A_power_p
148        p = p * 2
149
150    # For each requested power, compute it using the powers of 2
151    for p_idx, power in enumerate(unique_powers):
152        # Decompose power into sum of powers of 2
153        temp_result: Float[torch.Tensor, "n n"] = powers_of_two[0].clone()
154        temp_power: int = power
155        p_temp: int = 1
156
157        while temp_power > 0:
158            if temp_power % 2 == 1:
159                temp_result = temp_result @ powers_of_two[p_temp]
160            temp_power = temp_power // 2
161            p_temp *= 2
162
163        output[p_idx] = temp_result
164
165    return output

Compute multiple powers of a matrix efficiently.

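Uses binary exponentiation: the repeated squares A, A^2, A^4, … are computed once, using O(log max(powers)) matrix multiplications. They are shared by all requested powers, each of which then needs at most O(log power) additional multiplications, which avoids redundant work when computing multiple powers.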

Parameters:

  • A : Float[torch.Tensor, "n n"] Square matrix to exponentiate
  • powers : Sequence[int] List of powers to compute (non-negative integers)

Returns:

  • Float[torch.Tensor, "n_powers n n"] Tensor containing the requested matrix powers (duplicates removed, sorted by ascending power) stacked along the first dimension

Raises:

  • ValueError : If no powers are requested or if A is not a square matrix