Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""LU decomposition functions.""" 

2 

3from warnings import warn 

4 

5from numpy import asarray, asarray_chkfinite 

6 

7# Local imports 

8from .misc import _datacopied, LinAlgWarning 

9from .lapack import get_lapack_funcs 

10from .flinalg import get_flinalg_funcs 

11 

# Public names of this module (what ``from ... import *`` exposes).
__all__ = ['lu', 'lu_solve', 'lu_factor']

13 

14 

def lu_factor(a, overwrite_a=False, check_finite=True):
    """
    Compute pivoted LU decomposition of a matrix.

    The decomposition is::

        A = P L U

    where P is a permutation matrix, L lower triangular with unit
    diagonal elements, and U upper triangular.

    Parameters
    ----------
    a : (M, M) array_like
        Square matrix to decompose.
    overwrite_a : bool, optional
        Whether to overwrite data in A (may increase performance).
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    lu : (M, M) ndarray
        Matrix containing U in its upper triangle, and L in its lower triangle.
        The unit diagonal elements of L are not stored.
    piv : (M,) ndarray
        Pivot indices representing the permutation matrix P:
        row i of matrix was interchanged with row piv[i].

    Raises
    ------
    ValueError
        If `a` is not a square 2-D array, or if LAPACK ``getrf`` reports an
        illegal argument.

    Warns
    -----
    LinAlgWarning
        If a diagonal element of U is exactly zero (the matrix is singular).

    See also
    --------
    lu_solve : solve an equation system using the LU factorization of a matrix

    Notes
    -----
    This is a wrapper to the ``*GETRF`` routines from LAPACK.

    Examples
    --------
    >>> from scipy.linalg import lu_factor
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> lu, piv = lu_factor(A)
    >>> piv
    array([2, 2, 3, 3], dtype=int32)

    Convert LAPACK's sequential row swaps in ``piv`` to a NumPy row index
    and test the permutation:

    >>> perm = np.arange(len(piv))
    >>> for i, p in enumerate(piv):
    ...     perm[i], perm[p] = perm[p], perm[i]
    >>> perm
    array([2, 0, 3, 1])
    >>> L, U = np.tril(lu, k=-1) + np.eye(4), np.triu(lu)
    >>> np.allclose(A[perm] - L @ U, np.zeros((4, 4)))
    True
    """
    a1 = asarray_chkfinite(a) if check_finite else asarray(a)
    if a1.ndim != 2 or a1.shape[0] != a1.shape[1]:
        raise ValueError('expected square matrix')
    # _datacopied is expected to report whether `a1` is a private copy of
    # `a`; if so, letting LAPACK work in place cannot affect the caller.
    overwrite_a = overwrite_a or _datacopied(a1, a)
    getrf, = get_lapack_funcs(('getrf',), (a1,))
    lu, piv, info = getrf(a1, overwrite_a=overwrite_a)
    if info < 0:
        # Negative info: the (-info)-th argument passed to getrf was invalid.
        raise ValueError('illegal value in %dth argument of '
                         'internal getrf (lu_factor)' % -info)
    if info > 0:
        # Positive info: U[info-1, info-1] is exactly zero. The factorization
        # is still returned, so warn instead of raising.
        warn("Diagonal number %d is exactly zero. Singular matrix." % info,
             LinAlgWarning, stacklevel=2)
    return lu, piv

85 

86 

def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """Solve an equation system, a x = b, given the LU factorization of a.

    Parameters
    ----------
    (lu, piv) : tuple
        Factorization of the coefficient matrix a, as given by lu_factor.
    b : array
        Right-hand side. Its first dimension must match that of ``lu``.
    trans : {0, 1, 2}, optional
        Type of system to solve:

        =====  =========
        trans  system
        =====  =========
        0      a x   = b
        1      a^T x = b
        2      a^H x = b
        =====  =========
    overwrite_b : bool, optional
        Whether to overwrite data in b (may increase performance).
    check_finite : bool, optional
        Whether to check that `b` contains only finite numbers (the
        factorization ``lu`` itself is not checked here).
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    x : array
        Solution to the system.

    Raises
    ------
    ValueError
        If `b` is 0-dimensional or its first dimension does not match
        ``lu``, or if LAPACK ``getrs`` reports an illegal argument.

    See also
    --------
    lu_factor : LU factorize a matrix

    Examples
    --------
    >>> from scipy.linalg import lu_factor, lu_solve
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> b = np.array([1, 1, 1, 1])
    >>> lu, piv = lu_factor(A)
    >>> x = lu_solve((lu, piv), b)
    >>> np.allclose(A @ x - b, np.zeros((4,)))
    True

    """
    (lu, piv) = lu_and_piv
    b1 = asarray_chkfinite(b) if check_finite else asarray(b)
    # _datacopied is expected to report whether `b1` is a private copy of
    # `b`; if so, LAPACK may overwrite it without affecting the caller.
    overwrite_b = overwrite_b or _datacopied(b1, b)
    # A 0-d right-hand side has no leading dimension; report it as a shape
    # mismatch instead of letting ``b1.shape[0]`` raise IndexError.
    if b1.ndim == 0 or lu.shape[0] != b1.shape[0]:
        raise ValueError("incompatible dimensions.")

    getrs, = get_lapack_funcs(('getrs',), (lu, b1))
    x, info = getrs(lu, piv, b1, trans=trans, overwrite_b=overwrite_b)
    if info == 0:
        return x
    # Non-zero info from getrs: the (-info)-th argument was invalid.
    raise ValueError('illegal value in %dth argument of internal getrs '
                     '(lu_solve)' % -info)

148 

149 

def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
    """
    Compute the pivoted LU decomposition of a (possibly rectangular) matrix.

    The factors satisfy::

        A = P L U

    with P a permutation matrix, L lower triangular (or trapezoidal) with
    ones on its diagonal, and U upper triangular (or trapezoidal).

    Parameters
    ----------
    a : (M, N) array_like
        Two-dimensional array to factorize.
    permute_l : bool, optional
        If True, return the product P*L in place of the separate P and L
        factors. Default is False.
    overwrite_a : bool, optional
        Allow the data in `a` to be overwritten (may improve performance).
    check_finite : bool, optional
        Whether to verify that `a` holds only finite numbers. Skipping the
        check may be faster, but inputs containing infinities or NaNs can
        then cause problems (crashes, non-termination).

    Returns
    -------
    **(If permute_l == False)**

    p : (M, M) ndarray
        Permutation matrix.
    l : (M, K) ndarray
        Lower triangular or trapezoidal matrix with unit diagonal,
        where K = min(M, N).
    u : (K, N) ndarray
        Upper triangular or trapezoidal matrix.

    **(If permute_l == True)**

    pl : (M, K) ndarray
        The permuted lower factor P*L, where K = min(M, N).
    u : (K, N) ndarray
        Upper triangular or trapezoidal matrix.

    Raises
    ------
    ValueError
        If `a` is not two-dimensional, or if the underlying compiled routine
        reports an illegal argument.

    See also
    --------
    lu_factor : LU factorization of a square matrix in compact LAPACK form

    Notes
    -----
    Unlike `lu_factor`, this function accepts rectangular input and returns
    the factors as separate dense arrays.

    Examples
    --------
    >>> from scipy.linalg import lu
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> p, l, u = lu(A)
    >>> np.allclose(A - p @ l @ u, np.zeros((4, 4)))
    True
    >>> pl, u = lu(A, permute_l=True)
    >>> np.allclose(A - pl @ u, np.zeros((4, 4)))
    True

    """
    # Convert to an ndarray, optionally rejecting inf/NaN entries up front.
    arr = asarray(a) if not check_finite else asarray_chkfinite(a)
    if arr.ndim != 2:
        raise ValueError('expected matrix')
    # _datacopied is expected to report whether `arr` is a private copy of
    # `a`; if so, the compiled routine may safely work in place.
    overwrite_a = overwrite_a or _datacopied(arr, a)
    lu_routine, = get_flinalg_funcs(('lu',), (arr,))
    perm, lower, upper, info = lu_routine(arr, permute_l=permute_l,
                                          overwrite_a=overwrite_a)
    if info < 0:
        raise ValueError('illegal value in %dth argument of internal lu.getrf'
                         % -info)
    # With permute_l set, `lower` already holds P*L, so P is not returned.
    if not permute_l:
        return perm, lower, upper
    return lower, upper