Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

#
# Author: Travis Oliphant, March 2002
#

4 

# Public API of this module: elementary matrix functions, the generic
# Schur-based `funm`, and a few related helpers.
__all__ = [
    'expm', 'cosm', 'sinm', 'tanm',
    'coshm', 'sinhm', 'tanhm',
    'logm', 'funm', 'signm', 'sqrtm',
    'expm_frechet', 'expm_cond', 'fractional_matrix_power',
    'khatri_rao',
]

9 

10from numpy import (Inf, dot, diag, prod, logical_not, ravel, 

11 transpose, conjugate, absolute, amax, sign, isfinite, single) 

12import numpy as np 

13 

14# Local imports 

15from .misc import norm 

16from .basic import solve, inv 

17from .special_matrices import triu 

18from .decomp_svd import svd 

19from .decomp_schur import schur, rsf2csf 

20from ._expm_frechet import expm_frechet, expm_cond 

21from ._matfuncs_sqrtm import sqrtm 

22 

# Machine epsilons used to derive default tolerances in this module:
# ``eps`` for double precision and ``feps`` for single precision.
eps = np.finfo(np.float64).eps
feps = np.finfo(single).eps

# Map a NumPy dtype character code to a precision class used to pick a
# tolerance: 0 selects the single-precision tolerance ('f', 'F'), 1 selects
# the double-precision tolerance ('d', 'D') and is also used for the
# integer codes 'i' and 'l'.
_array_precision = {
    'i': 1,
    'l': 1,
    'f': 0,
    'd': 1,
    'F': 0,
    'D': 1,
}

27 

28 

29############################################################################### 

30# Utility functions. 

31 

32 

def _asarray_square(A):
    """
    Convert `A` to an ndarray and require it to be a square matrix.

    The matrix functions in this module lift scalar functions to square
    matrices, so they all share this validation step.

    Parameters
    ----------
    A : array_like
        Candidate square matrix.

    Returns
    -------
    out : ndarray
        `A` as converted by ``np.asarray`` (possibly the same object, a view,
        or a copy).

    Raises
    ------
    ValueError
        If `A` is not two-dimensional or its two dimensions differ.

    """
    arr = np.asarray(A)
    is_square = arr.ndim == 2 and arr.shape[0] == arr.shape[1]
    if not is_square:
        raise ValueError('expected square array_like input')
    return arr

55 

56 

def _maybe_real(A, B, tol=None):
    """
    Drop negligible imaginary parts of `B` when `A` is real.

    `B` is expected to have been computed as some complicated function of
    `A`.  When `A` is real but `B` came out complex, imaginary parts of `B`
    that are all within `tol` of zero are treated as numerical artifacts and
    discarded.

    Parameters
    ----------
    A : ndarray
        Array whose real-vs-complex type is inspected.
    B : ndarray
        Candidate result; returned either unchanged or as its real part.
    tol : float, optional
        Absolute tolerance for the imaginary parts.  When omitted it is
        chosen from the precision of `B`: ``1e3*feps`` for single precision,
        ``1e6*eps`` for double precision.

    Returns
    -------
    out : ndarray
        ``B.real`` if `A` is real, `B` is complex and every imaginary part of
        `B` is within `tol` of zero; otherwise `B` itself.

    """
    # Booleans and integers count as real here.
    if not (np.isrealobj(A) and np.iscomplexobj(B)):
        return B
    if tol is None:
        precision = _array_precision[B.dtype.char]
        tol = feps * 1e3 if precision == 0 else eps * 1e6
    return B.real if np.allclose(B.imag, 0.0, atol=tol) else B

89 

90 

91############################################################################### 

92# Matrix functions. 

93 

94 

def fractional_matrix_power(A, t):
    """
    Compute the fractional power of a matrix.

    Follows the approach described in section (6) of [1]_.

    Parameters
    ----------
    A : (N, N) array_like
        Matrix whose fractional power to evaluate.
    t : float
        Fractional power.

    Returns
    -------
    X : (N, N) array_like
        The fractional power of the matrix.

    References
    ----------
    .. [1] Nicholas J. Higham and Lijing Lin (2011)
           "A Schur-Pade Algorithm for Fractional Powers of a Matrix."
           SIAM Journal on Matrix Analysis and Applications,
           32 (3). pp. 1056-1078. ISSN 0895-4798

    Examples
    --------
    >>> from scipy.linalg import fractional_matrix_power
    >>> a = np.array([[1.0, 3.0], [1.0, 4.0]])
    >>> b = fractional_matrix_power(a, 0.5)
    >>> b
    array([[ 0.75592895,  1.13389342],
           [ 0.37796447,  1.88982237]])
    >>> np.dot(b, b)      # Verify square root
    array([[ 1.,  3.],
           [ 1.,  4.]])

    """
    A = _asarray_square(A)
    # The implementation relies on onenormest from scipy.sparse, so it is
    # imported here at call time rather than at module import time to
    # avoid import-order problems.
    from scipy.linalg._matfuncs_inv_ssq import _fractional_matrix_power
    return _fractional_matrix_power(A, t)

138 

139 

def logm(A, disp=True):
    """
    Compute the matrix logarithm.

    The matrix logarithm is the inverse of `expm`:
    ``expm(logm(A)) == A`` up to rounding.

    Parameters
    ----------
    A : (N, N) array_like
        Matrix whose logarithm to evaluate.
    disp : bool, optional
        If True (default), print a warning when the estimated error is large
        (non-finite or at least ``1000*eps``) and return only the logarithm.
        If False, return the estimated error as well instead of printing.

    Returns
    -------
    logm : (N, N) ndarray
        Matrix logarithm of `A`.
    errest : float
        Only returned when ``disp`` is False.  Relative residual
        ``||expm(logm) - A||_1 / ||A||_1``.

    References
    ----------
    .. [1] Awad H. Al-Mohy and Nicholas J. Higham (2012)
           "Improved Inverse Scaling and Squaring Algorithms
           for the Matrix Logarithm."
           SIAM Journal on Scientific Computing, 34 (4). C152-C169.
           ISSN 1095-7197

    .. [2] Nicholas J. Higham (2008)
           "Functions of Matrices: Theory and Computation"
           ISBN 978-0-898716-46-7

    .. [3] Nicholas J. Higham and Lijing Lin (2011)
           "A Schur-Pade Algorithm for Fractional Powers of a Matrix."
           SIAM Journal on Matrix Analysis and Applications,
           32 (3). pp. 1056-1078. ISSN 0895-4798

    Examples
    --------
    >>> from scipy.linalg import logm, expm
    >>> a = np.array([[1.0, 3.0], [1.0, 4.0]])
    >>> b = logm(a)
    >>> b
    array([[-1.02571087,  2.05142174],
           [ 0.68380725,  1.02571087]])
    >>> expm(b)         # Verify expm(logm(a)) returns a
    array([[ 1.,  3.],
           [ 1.,  4.]])

    """
    A = _asarray_square(A)
    # Imported at call time to avoid a circular import.
    from scipy.linalg._matfuncs_inv_ssq import _logm
    F = _maybe_real(A, _logm(A))

    # Relative 1-norm residual of expm(F) against A.
    # TODO: use a better error approximation.
    errest = norm(expm(F) - A, 1) / norm(A, 1)

    if not disp:
        return F, errest

    errtol = 1000 * eps
    if not isfinite(errest) or errest >= errtol:
        print("logm result may be inaccurate, approximate err =", errest)
    return F

208 

209 

def expm(A):
    """
    Compute the matrix exponential using Pade approximation.

    Parameters
    ----------
    A : (N, N) array_like or sparse matrix
        Matrix to be exponentiated.

    Returns
    -------
    expm : (N, N) ndarray
        Matrix exponential of `A`.

    References
    ----------
    .. [1] Awad H. Al-Mohy and Nicholas J. Higham (2009)
           "A New Scaling and Squaring Algorithm for the Matrix Exponential."
           SIAM Journal on Matrix Analysis and Applications.
           31 (3). pp. 970-989. ISSN 1095-7162

    Examples
    --------
    >>> from scipy.linalg import expm, sinm, cosm

    Matrix version of the formula exp(0) = 1:

    >>> expm(np.zeros((2,2)))
    array([[ 1.,  0.],
           [ 0.,  1.]])

    Euler's identity (exp(i*theta) = cos(theta) + i*sin(theta))
    applied to a matrix:

    >>> a = np.array([[1.0, 2.0], [-1.0, 3.0]])
    >>> expm(1j*a)
    array([[ 0.42645930+1.89217551j, -2.13721484-0.97811252j],
           [ 1.06860742+0.48905626j, -1.71075555+0.91406299j]])
    >>> cosm(a) + 1j*sinm(a)
    array([[ 0.42645930+1.89217551j, -2.13721484-0.97811252j],
           [ 1.06860742+0.48905626j, -1.71075555+0.91406299j]])

    """
    # Delegate entirely to scipy.sparse.linalg.expm, which performs the input
    # checking and conversion.  The import happens at call time, presumably
    # to avoid an import cycle with scipy.sparse -- TODO confirm.
    from scipy.sparse.linalg import expm as _sparse_expm
    return _sparse_expm(A)

256 

257 

def cosm(A):
    """
    Compute the matrix cosine.

    The result is obtained from matrix exponentials computed with `expm`.

    Parameters
    ----------
    A : (N, N) array_like
        Input array.

    Returns
    -------
    cosm : (N, N) ndarray
        Matrix cosine of `A`.

    Examples
    --------
    >>> from scipy.linalg import expm, sinm, cosm

    Euler's identity (exp(i*theta) = cos(theta) + i*sin(theta))
    applied to a matrix:

    >>> a = np.array([[1.0, 2.0], [-1.0, 3.0]])
    >>> expm(1j*a)
    array([[ 0.42645930+1.89217551j, -2.13721484-0.97811252j],
           [ 1.06860742+0.48905626j, -1.71075555+0.91406299j]])
    >>> cosm(a) + 1j*sinm(a)
    array([[ 0.42645930+1.89217551j, -2.13721484-0.97811252j],
           [ 1.06860742+0.48905626j, -1.71075555+0.91406299j]])

    """
    A = _asarray_square(A)
    if not np.iscomplexobj(A):
        # For real A, exp(iA) = cos(A) + i*sin(A) with cos(A) and sin(A)
        # both real, so a single exponential suffices.
        return expm(1j*A).real
    # General complex case: cos(A) = (exp(iA) + exp(-iA)) / 2.
    return 0.5*(expm(1j*A) + expm(-1j*A))

295 

296 

def sinm(A):
    """
    Compute the matrix sine.

    The result is obtained from matrix exponentials computed with `expm`.

    Parameters
    ----------
    A : (N, N) array_like
        Input array.

    Returns
    -------
    sinm : (N, N) ndarray
        Matrix sine of `A`.

    Examples
    --------
    >>> from scipy.linalg import expm, sinm, cosm

    Euler's identity (exp(i*theta) = cos(theta) + i*sin(theta))
    applied to a matrix:

    >>> a = np.array([[1.0, 2.0], [-1.0, 3.0]])
    >>> expm(1j*a)
    array([[ 0.42645930+1.89217551j, -2.13721484-0.97811252j],
           [ 1.06860742+0.48905626j, -1.71075555+0.91406299j]])
    >>> cosm(a) + 1j*sinm(a)
    array([[ 0.42645930+1.89217551j, -2.13721484-0.97811252j],
           [ 1.06860742+0.48905626j, -1.71075555+0.91406299j]])

    """
    A = _asarray_square(A)
    if not np.iscomplexobj(A):
        # For real A, sin(A) is exactly the imaginary part of exp(iA).
        return expm(1j*A).imag
    # General complex case: sin(A) = (exp(iA) - exp(-iA)) / (2i).
    return -0.5j*(expm(1j*A) - expm(-1j*A))

334 

335 

def tanm(A):
    """
    Compute the matrix tangent.

    The result is ``cosm(A)^{-1} sinm(A)``, where both factors are built from
    matrix exponentials computed with `expm`.

    Parameters
    ----------
    A : (N, N) array_like
        Input array.

    Returns
    -------
    tanm : (N, N) ndarray
        Matrix tangent of `A`.

    Examples
    --------
    >>> from scipy.linalg import tanm, sinm, cosm
    >>> a = np.array([[1.0, 3.0], [1.0, 4.0]])
    >>> t = tanm(a)
    >>> t
    array([[ -2.00876993,  -8.41880636],
           [ -2.80626879, -10.42757629]])

    Verify tanm(a) = sinm(a).dot(inv(cosm(a)))

    >>> s = sinm(a)
    >>> c = cosm(a)
    >>> s.dot(np.linalg.inv(c))
    array([[ -2.00876993,  -8.41880636],
           [ -2.80626879, -10.42757629]])

    """
    A = _asarray_square(A)
    cosine = cosm(A)
    sine = sinm(A)
    # Solve cos(A) X = sin(A) rather than forming an explicit inverse.
    tangent = solve(cosine, sine)
    return _maybe_real(A, tangent)

372 

373 

def coshm(A):
    """
    Compute the hyperbolic matrix cosine.

    The result is ``(expm(A) + expm(-A)) / 2``.

    Parameters
    ----------
    A : (N, N) array_like
        Input array.

    Returns
    -------
    coshm : (N, N) ndarray
        Hyperbolic matrix cosine of `A`.

    Examples
    --------
    >>> from scipy.linalg import tanhm, sinhm, coshm
    >>> a = np.array([[1.0, 3.0], [1.0, 4.0]])
    >>> c = coshm(a)
    >>> c
    array([[ 11.24592233,  38.76236492],
           [ 12.92078831,  50.00828725]])

    Verify tanhm(a) = sinhm(a).dot(inv(coshm(a)))

    >>> t = tanhm(a)
    >>> s = sinhm(a)
    >>> t - s.dot(np.linalg.inv(c))
    array([[  2.72004641e-15,   4.55191440e-15],
           [  0.00000000e+00,  -5.55111512e-16]])

    """
    A = _asarray_square(A)
    exp_pos = expm(A)
    exp_neg = expm(-A)
    return _maybe_real(A, 0.5 * (exp_pos + exp_neg))

410 

411 

def sinhm(A):
    """
    Compute the hyperbolic matrix sine.

    The result is ``(expm(A) - expm(-A)) / 2``.

    Parameters
    ----------
    A : (N, N) array_like
        Input array.

    Returns
    -------
    sinhm : (N, N) ndarray
        Hyperbolic matrix sine of `A`.

    Examples
    --------
    >>> from scipy.linalg import tanhm, sinhm, coshm
    >>> a = np.array([[1.0, 3.0], [1.0, 4.0]])
    >>> s = sinhm(a)
    >>> s
    array([[ 10.57300653,  39.28826594],
           [ 13.09608865,  49.86127247]])

    Verify tanhm(a) = sinhm(a).dot(inv(coshm(a)))

    >>> t = tanhm(a)
    >>> c = coshm(a)
    >>> t - s.dot(np.linalg.inv(c))
    array([[  2.72004641e-15,   4.55191440e-15],
           [  0.00000000e+00,  -5.55111512e-16]])

    """
    A = _asarray_square(A)
    exp_pos = expm(A)
    exp_neg = expm(-A)
    return _maybe_real(A, 0.5 * (exp_pos - exp_neg))

448 

449 

def tanhm(A):
    """
    Compute the hyperbolic matrix tangent.

    The result is ``coshm(A)^{-1} sinhm(A)``, where both factors are built
    from matrix exponentials computed with `expm`.

    Parameters
    ----------
    A : (N, N) array_like
        Input array.

    Returns
    -------
    tanhm : (N, N) ndarray
        Hyperbolic matrix tangent of `A`.

    Examples
    --------
    >>> from scipy.linalg import tanhm, sinhm, coshm
    >>> a = np.array([[1.0, 3.0], [1.0, 4.0]])
    >>> t = tanhm(a)
    >>> t
    array([[ 0.3428582 ,  0.51987926],
           [ 0.17329309,  0.86273746]])

    Verify tanhm(a) = sinhm(a).dot(inv(coshm(a)))

    >>> s = sinhm(a)
    >>> c = coshm(a)
    >>> t - s.dot(np.linalg.inv(c))
    array([[  2.72004641e-15,   4.55191440e-15],
           [  0.00000000e+00,  -5.55111512e-16]])

    """
    A = _asarray_square(A)
    cosh_a = coshm(A)
    sinh_a = sinhm(A)
    # Solve cosh(A) X = sinh(A) rather than forming an explicit inverse.
    return _maybe_real(A, solve(cosh_a, sinh_a))

486 

487 

def funm(A, func, disp=True):
    """
    Evaluate a matrix function specified by a callable.

    Returns the value of matrix-valued function ``f`` at `A`. The
    function ``f`` is an extension of the scalar-valued function `func`
    to matrices.

    Parameters
    ----------
    A : (N, N) array_like
        Matrix at which to evaluate the function
    func : callable
        Callable object that evaluates a scalar function f.
        Must be vectorized (e.g. using vectorize): it is called once with
        the 1-D array of diagonal entries of the Schur factor of `A`.
    disp : bool, optional
        If True (default), print a warning when the estimated error of the
        result is large and return only the result.  If False, return the
        estimated error as well instead of printing.

    Returns
    -------
    funm : (N, N) ndarray
        Value of the matrix function specified by func evaluated at `A`
    errest : float
        Only returned when ``disp`` is False.  Estimated relative error of
        the result, capped at 1; ``inf`` if any entry of the result is not
        finite.

    Examples
    --------
    >>> from scipy.linalg import funm
    >>> a = np.array([[1.0, 3.0], [1.0, 4.0]])
    >>> funm(a, lambda x: x*x)
    array([[  4.,  15.],
           [  5.,  19.]])
    >>> a.dot(a)
    array([[  4.,  15.],
           [  5.,  19.]])

    Notes
    -----
    This function implements the general algorithm based on Schur
    decomposition (Algorithm 9.1.1 in [1]_; Algorithm 11.1.1 in the 3rd
    edition).

    If the input matrix is known to be diagonalizable, then relying on the
    eigendecomposition is likely to be faster. For example, if your matrix is
    Hermitian, you can do

    >>> from scipy.linalg import eigh
    >>> def funm_herm(a, func, check_finite=False):
    ...     w, v = eigh(a, check_finite=check_finite)
    ...     ## if you further know that your matrix is positive semidefinite,
    ...     ## you can optionally guard against precision errors by doing
    ...     # w = np.maximum(w, 0)
    ...     w = func(w)
    ...     return (v * w).dot(v.conj().T)

    References
    ----------
    .. [1] Gene H. Golub, Charles F. van Loan, Matrix Computations 4th ed.

    """
    A = _asarray_square(A)
    # Schur decomposition A = Z T Z^H (LAPACK ?gees), converted to complex
    # Schur form so that T is upper triangular.
    T, Z = schur(A)
    T, Z = rsf2csf(T, Z)
    n = T.shape[0]

    # The diagonal of f(T) is func applied to the diagonal of T.
    F = diag(func(diag(T)))
    F = F.astype(T.dtype.char)  # e.g., when F is real but T is complex

    # Smallest diagonal separation |T[j,j] - T[i,i]| seen; feeds the error
    # estimate below.
    minden = abs(T[0, 0])

    # Parlett recurrence (Golub & Van Loan, "Matrix Computations":
    # Algorithm 9.1.1 in the 4th ed., 11.1.1 in the 3rd ed.): fill the
    # strictly upper triangle of F one superdiagonal (p) at a time.
    # Indices i, j below are 1-based, hence the "-1" offsets.
    for p in range(1, n):
        for i in range(1, n - p + 1):
            j = i + p
            s = T[i-1, j-1] * (F[j-1, j-1] - F[i-1, i-1])
            ksl = slice(i, j-1)
            val = dot(T[i-1, ksl], F[ksl, j-1]) - dot(F[i-1, ksl], T[ksl, j-1])
            s = s + val
            den = T[j-1, j-1] - T[i-1, i-1]
            if den != 0.0:
                # Equal diagonal entries (den == 0) leave s unscaled; minden
                # then becomes 0 and the error estimate reflects that.
                s = s / den
            F[i-1, j-1] = s
            minden = min(minden, abs(den))

    # Transform back: f(A) = Z f(T) Z^H.
    F = dot(dot(Z, F), transpose(conjugate(Z)))
    F = _maybe_real(A, F)

    tol = {0: feps, 1: eps}[_array_precision[F.dtype.char]]
    if minden == 0.0:
        minden = tol
    err = min(1, max(tol, (tol/minden)*norm(triu(T, 1), 1)))
    # A single inf/nan entry already makes the result unreliable, so flag
    # the estimate whenever ANY entry is non-finite (not only when all are).
    if not np.all(isfinite(F)):
        err = np.inf
    if disp:
        if err > 1000*tol:
            print("funm result may be inaccurate, approximate err =", err)
        return F
    else:
        return F, err

590 

591 

def signm(A, disp=True):
    """
    Matrix sign function.

    Extension of the scalar sign(x) to matrices.

    Parameters
    ----------
    A : (N, N) array_like
        Matrix at which to evaluate the sign function
    disp : bool, optional
        If True (default), print a warning when the estimated error of the
        result is large and return only the result.  If False, return the
        estimated error as well instead of printing.

    Returns
    -------
    signm : (N, N) ndarray
        Value of the sign function at `A`
    errest : float
        Only returned when ``disp`` is False.  Either the error estimate
        reported by `funm`, or, when the fallback iteration was needed,
        ``||P @ P - P||_1`` with ``P = (S @ S + S) / 2`` for the final
        iterate ``S``.

    Examples
    --------
    >>> from scipy.linalg import signm, eigvals
    >>> a = [[1,2,3], [1,2,1], [1,1,1]]
    >>> eigvals(a)
    array([ 4.12488542+0.j, -0.76155718+0.j,  0.63667176+0.j])
    >>> eigvals(signm(a))
    array([-1.+0.j,  1.+0.j,  1.+0.j])

    """
    A = _asarray_square(A)

    def rounded_sign(x):
        # Sign of the real parts of the diagonal entries ``x`` passed in by
        # funm, with real parts that are tiny relative to the largest
        # magnitude in ``x`` rounded to zero.  The scale uses absolute(x) so
        # the threshold is non-negative even when every entry has a negative
        # real part, and real-valued even when ``x`` is complex.
        rx = np.real(x)
        scale = amax(absolute(x))
        if rx.dtype.char == 'f':
            c = 1e3*feps*scale
        else:
            c = 1e3*eps*scale
        return sign((absolute(rx) > c) * rx)

    result, errest = funm(A, rounded_sign, disp=0)
    errtol = {0: 1e3*feps, 1: 1e3*eps}[_array_precision[result.dtype.char]]
    if errest < errtol:
        # funm's result is already accurate enough.  Honour the documented
        # return convention: a (result, errest) tuple when disp is False.
        if disp:
            return result
        return result, errest

    # Handle signm of defective matrices:

    # See "E.D.Denman and J.Leyva-Ramos, Appl.Math.Comp.,
    # 8:237-250,1981" for how to improve the following (currently a
    # rather naive) iteration process:

    # a = result # sometimes iteration converges faster but where??

    # Shifting to avoid zero eigenvalues. How to ensure that shifting does
    # not change the spectrum too much?
    vals = svd(A, compute_uv=0)
    max_sv = np.amax(vals)
    # min_nonzero_sv = vals[(vals>max_sv*errtol).tolist().count(1)-1]
    # c = 0.5/min_nonzero_sv
    c = 0.5/max_sv
    S0 = A + c*np.identity(A.shape[0])
    prev_errest = errest
    for i in range(100):
        # Iteration S <- (S + S^{-1}) / 2.
        iS0 = inv(S0)
        S0 = 0.5*(S0 + iS0)
        # P = (S^2 + S) / 2; stop when P is (numerically) idempotent or the
        # residual stops changing.
        Pp = 0.5*(dot(S0, S0)+S0)
        errest = norm(dot(Pp, Pp)-Pp, 1)
        if errest < errtol or prev_errest == errest:
            break
        prev_errest = errest
    if disp:
        if not isfinite(errest) or errest >= errtol:
            print("signm result may be inaccurate, approximate err =", errest)
        return S0
    else:
        return S0, errest

670 

671 

def khatri_rao(a, b):
    r"""
    Khatri-Rao product.

    A column-wise Kronecker product of two matrices.

    Parameters
    ----------
    a : (n, k) array_like
        Input array
    b : (m, k) array_like
        Input array

    Returns
    -------
    c : (n*m, k) ndarray
        Khatri-Rao product of `a` and `b`.

    Raises
    ------
    ValueError
        If `a` or `b` is not two-dimensional, or if they have a different
        number of columns.

    Notes
    -----
    The Khatri-Rao product is the column-wise Kronecker product: column
    :math:`j` of the result is the Kronecker product of column :math:`j`
    of `a` with column :math:`j` of `b`,

    .. math::

        C_{:, j} = A_{:, j} \otimes B_{:, j}, \qquad j = 1, \dots, k

    which for non-empty inputs is equivalent to::

        c = np.vstack([np.kron(a[:, k], b[:, k]) for k in range(b.shape[1])]).T

    See Also
    --------
    kron : Kronecker product

    Examples
    --------
    >>> from scipy import linalg
    >>> a = np.array([[1, 2, 3], [4, 5, 6]])
    >>> b = np.array([[3, 4, 5], [6, 7, 8], [2, 3, 9]])
    >>> linalg.khatri_rao(a, b)
    array([[ 3,  8, 15],
           [ 6, 14, 24],
           [ 2,  6, 27],
           [12, 20, 30],
           [24, 35, 48],
           [ 8, 15, 54]])

    """
    a = np.asarray(a)
    b = np.asarray(b)

    if not (a.ndim == 2 and b.ndim == 2):
        raise ValueError("The both arrays should be 2-dimensional.")

    if not a.shape[1] == b.shape[1]:
        raise ValueError("The number of columns for both arrays "
                         "should be equal.")

    n, k = a.shape
    m = b.shape[0]
    # Broadcast (n, 1, k) * (1, m, k) -> (n, m, k): entry [i, p, j] is
    # a[i, j] * b[p, j].  Flattening the first two axes in C order gives
    # exactly the row ordering of np.kron(a[:, j], b[:, j]).
    c = a[:, np.newaxis, :] * b[np.newaxis, :, :]
    # Use the explicit target shape: reshape(-1, 0) on an empty array is
    # ambiguous and raises, which broke inputs with zero columns.
    return c.reshape(n * m, k)