Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/scipy/linalg/decomp_lu.py : 19%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""LU decomposition functions."""
3from warnings import warn
5from numpy import asarray, asarray_chkfinite
7# Local imports
8from .misc import _datacopied, LinAlgWarning
9from .lapack import get_lapack_funcs
10from .flinalg import get_flinalg_funcs
12__all__ = ['lu', 'lu_solve', 'lu_factor']
def lu_factor(a, overwrite_a=False, check_finite=True):
    """
    Compute pivoted LU decomposition of a matrix.

    The decomposition is::

        A = P L U

    where P is a permutation matrix, L lower triangular with unit
    diagonal elements, and U upper triangular.

    Parameters
    ----------
    a : (M, M) array_like
        Matrix to decompose
    overwrite_a : bool, optional
        Whether to overwrite data in A (may increase performance)
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    lu : (M, M) ndarray
        Matrix containing U in its upper triangle, and L in its lower triangle.
        The unit diagonal elements of L are not stored.
    piv : (M,) ndarray
        Pivot indices representing the permutation matrix P:
        row i of matrix was interchanged with row piv[i]. The interchanges
        are applied in order, for i = 0, 1, ..., M-1.

    Raises
    ------
    ValueError
        If `a` is not a square 2-D array, if `check_finite` is True and `a`
        contains infinities or NaNs, or if LAPACK ``*GETRF`` reports an
        illegal argument.

    Warns
    -----
    LinAlgWarning
        If a diagonal element of U is exactly zero (singular matrix). The
        factorization is still returned in that case.

    See also
    --------
    lu_solve : solve an equation system using the LU factorization of a matrix

    Notes
    -----
    This is a wrapper to the ``*GETRF`` routines from LAPACK.

    Examples
    --------
    >>> from scipy.linalg import lu_factor
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> lu, piv = lu_factor(A)
    >>> piv
    array([2, 2, 3, 3], dtype=int32)

    Convert LAPACK's ``piv`` array to NumPy index and test the permutation

    >>> piv_py = [2, 0, 3, 1]
    >>> L, U = np.tril(lu, k=-1) + np.eye(4), np.triu(lu)
    >>> np.allclose(A[piv_py] - L @ U, np.zeros((4, 4)))
    True
    """
    a1 = asarray_chkfinite(a) if check_finite else asarray(a)
    if a1.ndim != 2 or a1.shape[0] != a1.shape[1]:
        raise ValueError('expected square matrix')
    # When the conversion above already produced a private copy (as judged by
    # ``_datacopied``), LAPACK may work on it in place without touching the
    # caller's data.
    overwrite_a = overwrite_a or _datacopied(a1, a)
    getrf, = get_lapack_funcs(('getrf',), (a1,))
    lu_mat, piv, info = getrf(a1, overwrite_a=overwrite_a)
    if info < 0:
        # LAPACK convention: INFO = -i means the i-th argument was illegal.
        raise ValueError('illegal value in %dth argument of '
                         'internal getrf (lu_factor)' % -info)
    if info > 0:
        # LAPACK convention: INFO = i > 0 means U(i, i) (1-based) is exactly
        # zero; the factorization is complete but U is singular.
        warn("Diagonal number %d is exactly zero. Singular matrix." % info,
             LinAlgWarning, stacklevel=2)
    return lu_mat, piv
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """Solve an equation system, a x = b, given the LU factorization of a.

    Parameters
    ----------
    (lu, piv)
        Factorization of the coefficient matrix a, as given by lu_factor
    b : array
        Right-hand side
    trans : {0, 1, 2}, optional
        Type of system to solve:

        =====  =========
        trans  system
        =====  =========
        0      a x   = b
        1      a^T x = b
        2      a^H x = b
        =====  =========
    overwrite_b : bool, optional
        Whether to overwrite data in b (may increase performance)
    check_finite : bool, optional
        Whether to check that `b` contains only finite numbers. The
        factorization ``(lu, piv)`` itself is not checked.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    x : array
        Solution to the system

    Raises
    ------
    ValueError
        If the first dimension of `b` does not match that of `lu`, if
        `check_finite` is True and `b` contains infinities or NaNs, or if
        LAPACK ``*GETRS`` reports an illegal argument.

    See also
    --------
    lu_factor : LU factorize a matrix

    Examples
    --------
    >>> from scipy.linalg import lu_factor, lu_solve
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> b = np.array([1, 1, 1, 1])
    >>> lu, piv = lu_factor(A)
    >>> x = lu_solve((lu, piv), b)
    >>> np.allclose(A @ x - b, np.zeros((4,)))
    True

    """
    lu_mat, piv = lu_and_piv
    # Accept any array_like factor (e.g. a nested list); this is a no-op for
    # the ndarray returned by lu_factor.
    lu_mat = asarray(lu_mat)
    b1 = asarray_chkfinite(b) if check_finite else asarray(b)
    # A private copy of b made by the conversion can be overwritten freely.
    overwrite_b = overwrite_b or _datacopied(b1, b)
    if lu_mat.shape[0] != b1.shape[0]:
        raise ValueError("incompatible dimensions.")

    getrs, = get_lapack_funcs(('getrs',), (lu_mat, b1))
    x, info = getrs(lu_mat, piv, b1, trans=trans, overwrite_b=overwrite_b)
    if info != 0:
        # LAPACK convention: INFO = -i means the i-th argument was illegal.
        raise ValueError('illegal value in %dth argument of internal getrs '
                         '(lu_solve)' % -info)
    return x
def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
    """
    Compute pivoted LU decomposition of a matrix.

    The decomposition is::

        A = P L U

    where P is a permutation matrix, L lower triangular with unit
    diagonal elements, and U upper triangular.

    Parameters
    ----------
    a : (M, N) array_like
        Array to decompose
    permute_l : bool, optional
        Perform the multiplication P*L (Default: do not permute)
    overwrite_a : bool, optional
        Whether to overwrite data in a (may improve performance)
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    **(If permute_l == False)**

    p : (M, M) ndarray
        Permutation matrix
    l : (M, K) ndarray
        Lower triangular or trapezoidal matrix with unit diagonal.
        K = min(M, N)
    u : (K, N) ndarray
        Upper triangular or trapezoidal matrix

    **(If permute_l == True)**

    pl : (M, K) ndarray
        Permuted L matrix.
        K = min(M, N)
    u : (K, N) ndarray
        Upper triangular or trapezoidal matrix

    Raises
    ------
    ValueError
        If `a` is not 2-D, if `check_finite` is True and `a` contains
        infinities or NaNs, or if the compiled routine reports an illegal
        argument (negative status code).

    Notes
    -----
    This is an LU factorization routine written for SciPy. The work is done
    by the compiled ``lu`` routine obtained from ``get_flinalg_funcs``. Only
    a negative status code from that routine raises; a positive one is not
    reported.

    Examples
    --------
    >>> from scipy.linalg import lu
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> p, l, u = lu(A)
    >>> np.allclose(A - p @ l @ u, np.zeros((4, 4)))
    True

    """
    convert = asarray_chkfinite if check_finite else asarray
    arr = convert(a)
    if arr.ndim != 2:
        raise ValueError('expected matrix')

    # Overwriting is harmless when the conversion already made a private copy.
    overwrite_a = overwrite_a or _datacopied(arr, a)

    (flu,) = get_flinalg_funcs(('lu',), (arr,))
    perm, lower, upper, status = flu(arr, permute_l=permute_l,
                                     overwrite_a=overwrite_a)
    if status < 0:
        raise ValueError('illegal value in %dth argument of '
                         'internal lu.getrf' % -status)

    # With permute_l, ``lower`` already holds the product P @ L.
    return (lower, upper) if permute_l else (perm, lower, upper)