Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Indexing mixin for sparse matrix classes. 

2""" 

3import numpy as np 

4from .sputils import isintlike 

5 

# Integer scalar types accepted as plain (non-array) row/column indices.
try:
    _int_like = (int, long)  # Python 2 keeps a distinct ``long`` type
except NameError:
    _int_like = (int,)       # Python 3: ``long`` no longer exists
INT_TYPES = _int_like + (np.integer,)
del _int_like

11 

12 

13def _broadcast_arrays(a, b): 

14 """ 

15 Same as np.broadcast_arrays(a, b) but old writeability rules. 

16 

17 NumPy >= 1.17.0 transitions broadcast_arrays to return 

18 read-only arrays. Set writeability explicitly to avoid warnings. 

19 Retain the old writeability rules, as our Cython code assumes 

20 the old behavior. 

21 """ 

22 x, y = np.broadcast_arrays(a, b) 

23 x.flags.writeable = a.flags.writeable 

24 y.flags.writeable = b.flags.writeable 

25 return x, y 

26 

27 

class IndexMixin(object):
    """
    This class provides common dispatching and validation logic for indexing.

    The mixin turns a user-supplied key into a validated ``(row, col)`` pair
    and forwards it to a specialised ``_get_*`` / ``_set_*`` hook. Each hook
    is chosen by whether the component is an integer, a slice, or an index
    array. The hooks defined here raise NotImplementedError and are meant to
    be overridden by subclasses.
    """
    def __getitem__(self, key):
        """Return ``self[key]`` by dispatching on the kind of row/col index.

        Raises IndexError when the combination of indices would produce a
        result with more than two dimensions.
        """
        row, col = self._validate_indices(key)
        # Dispatch to specialized methods.
        if isinstance(row, INT_TYPES):
            if isinstance(col, INT_TYPES):
                return self._get_intXint(row, col)
            elif isinstance(col, slice):
                return self._get_intXslice(row, col)
            elif col.ndim == 1:
                return self._get_intXarray(row, col)
            raise IndexError('index results in >2 dimensions')
        elif isinstance(row, slice):
            if isinstance(col, INT_TYPES):
                return self._get_sliceXint(row, col)
            elif isinstance(col, slice):
                # A full ``[:, :]`` selection is just a copy.
                if row == slice(None) and row == col:
                    return self.copy()
                return self._get_sliceXslice(row, col)
            elif col.ndim == 1:
                return self._get_sliceXarray(row, col)
            raise IndexError('index results in >2 dimensions')
        elif row.ndim == 1:
            if isinstance(col, INT_TYPES):
                return self._get_arrayXint(row, col)
            elif isinstance(col, slice):
                return self._get_arrayXslice(row, col)
        else:  # row.ndim == 2
            if isinstance(col, INT_TYPES):
                return self._get_arrayXint(row, col)
            elif isinstance(col, slice):
                raise IndexError('index results in >2 dimensions')
            elif row.shape[1] == 1 and (col.ndim == 1 or col.shape[0] == 1):
                # special case for outer indexing: a column of row indices
                # against a row (or 1-D vector) of column indices.
                return self._get_columnXarray(row[:, 0], col.ravel())

        # The only remaining case is inner (fancy) indexing
        row, col = _broadcast_arrays(row, col)
        # Defensive: np.broadcast_arrays already raises ValueError for
        # incompatible shapes, so after it succeeds the shapes agree.
        if row.shape != col.shape:
            raise IndexError('number of row and column indices differ')
        if row.size == 0:
            return self.__class__(np.atleast_2d(row).shape, dtype=self.dtype)
        return self._get_arrayXarray(row, col)

    def __setitem__(self, key, x):
        """Assign `x` to ``self[key]``.

        A single int-by-int position receives a scalar through
        `_set_intXint`. Every other key is expanded into broadcastable
        integer index arrays. A sparse `x` is forwarded to
        `_set_arrayXarray_sparse`; anything else is converted to a dense
        array of ``self.dtype`` and forwarded to `_set_arrayXarray`.

        Raises ValueError when a sequence is assigned to a single item, or
        when a sparse `x` cannot be broadcast to the indexed shape.
        """
        row, col = self._validate_indices(key)

        if isinstance(row, INT_TYPES) and isinstance(col, INT_TYPES):
            x = np.asarray(x, dtype=self.dtype)
            if x.size != 1:
                raise ValueError('Trying to assign a sequence to an item')
            self._set_intXint(row, col, x.flat[0])
            return

        # Slices become explicit index ranges shaped for outer broadcasting:
        # rows as a column vector, columns as a row vector.
        if isinstance(row, slice):
            row = np.arange(*row.indices(self.shape[0]))[:, None]
        else:
            row = np.atleast_1d(row)

        if isinstance(col, slice):
            col = np.arange(*col.indices(self.shape[1]))[None, :]
            if row.ndim == 1:
                row = row[:, None]
        else:
            col = np.atleast_1d(col)

        i, j = _broadcast_arrays(row, col)
        # Defensive: see the matching note in __getitem__.
        if i.shape != j.shape:
            raise IndexError('number of row and column indices differ')

        from .base import isspmatrix
        if isspmatrix(x):
            if i.ndim == 1:
                # Inner indexing, so treat them like row vectors.
                i = i[None]
                j = j[None]
            broadcast_row = x.shape[0] == 1 and i.shape[0] != 1
            broadcast_col = x.shape[1] == 1 and i.shape[1] != 1
            if not ((broadcast_row or x.shape[0] == i.shape[0]) and
                    (broadcast_col or x.shape[1] == i.shape[1])):
                raise ValueError('shape mismatch in assignment')
            # ``x.size`` on a sparse matrix counts *stored* values, so testing
            # it would silently skip assigning an all-zero sparse matrix.
            # Only bail out when the value has no elements at all.
            if x.shape[0] == 0 or x.shape[1] == 0:
                return
            x = x.tocoo(copy=True)
            x.sum_duplicates()
            self._set_arrayXarray_sparse(i, j, x)
        else:
            # Make x and i into the same shape
            x = np.asarray(x, dtype=self.dtype)
            x, _ = _broadcast_arrays(x, i)
            if x.size == 0:
                return
            x = x.reshape(i.shape)
            self._set_arrayXarray(i, j, x)

    def _validate_indices(self, key):
        """Normalize `key` into a validated ``(row, col)`` pair.

        Integer-like components are bounds-checked against ``self.shape``
        and wrapped to non-negative Python ints. Slices pass through
        unchanged. Anything else is converted by `_asindices`.

        Raises IndexError for an out-of-range integer row or column.
        """
        M, N = self.shape
        row, col = _unpack_index(key)

        if isintlike(row):
            row = int(row)
            if row < -M or row >= M:
                raise IndexError('row index (%d) out of range' % row)
            if row < 0:
                row += M
        elif not isinstance(row, slice):
            row = self._asindices(row, M)

        if isintlike(col):
            col = int(col)
            if col < -N or col >= N:
                raise IndexError('column index (%d) out of range' % col)
            if col < 0:
                col += N
        elif not isinstance(col, slice):
            col = self._asindices(col, N)

        return row, col

    def _asindices(self, idx, length):
        """Convert `idx` to a valid index for an axis with a given length.

        Returns an array of dimension 1 or 2 in which negative entries have
        been wrapped by adding `length`. The caller's array is never
        modified in place; a copy is made before wrapping when needed.

        Raises IndexError if `idx` cannot be converted to an array, does
        not have 1 or 2 dimensions, or holds an entry outside
        ``[-length, length)``.

        Subclasses that need special validation can override this method.
        """
        try:
            x = np.asarray(idx)
        except (ValueError, TypeError, MemoryError):
            raise IndexError('invalid index')

        if x.ndim not in (1, 2):
            raise IndexError('Index dimension must be 1 or 2')

        if x.size == 0:
            return x

        # Check bounds
        max_indx = x.max()
        if max_indx >= length:
            raise IndexError('index (%d) out of range' % max_indx)

        min_indx = x.min()
        if min_indx < 0:
            if min_indx < -length:
                raise IndexError('index (%d) out of range' % min_indx)
            # Don't wrap negatives in the caller's own buffer.
            if x is idx or not x.flags.owndata:
                x = x.copy()
            x[x < 0] += length
        return x

    def getrow(self, i):
        """Return a copy of row i of the matrix, as a (1 x n) row vector.

        Negative `i` counts from the end. Raises IndexError if `i` is out
        of range.
        """
        M, N = self.shape
        i = int(i)
        if i < -M or i >= M:
            raise IndexError('index (%d) out of range' % i)
        if i < 0:
            i += M
        return self._get_intXslice(i, slice(None))

    def getcol(self, i):
        """Return a copy of column i of the matrix, as a (m x 1) column vector.

        Negative `i` counts from the end. Raises IndexError if `i` is out
        of range.
        """
        M, N = self.shape
        i = int(i)
        if i < -N or i >= N:
            raise IndexError('index (%d) out of range' % i)
        if i < 0:
            i += N
        return self._get_sliceXint(slice(None), i)

    # ------------------------------------------------------------------
    # Subclass hooks. The name encodes the (row kind)X(col kind) pair the
    # dispatcher above routes to them; each default raises
    # NotImplementedError.
    # ------------------------------------------------------------------

    def _get_intXint(self, row, col):
        raise NotImplementedError()

    def _get_intXarray(self, row, col):
        raise NotImplementedError()

    def _get_intXslice(self, row, col):
        raise NotImplementedError()

    def _get_sliceXint(self, row, col):
        raise NotImplementedError()

    def _get_sliceXslice(self, row, col):
        raise NotImplementedError()

    def _get_sliceXarray(self, row, col):
        raise NotImplementedError()

    def _get_arrayXint(self, row, col):
        raise NotImplementedError()

    def _get_arrayXslice(self, row, col):
        raise NotImplementedError()

    def _get_columnXarray(self, row, col):
        raise NotImplementedError()

    def _get_arrayXarray(self, row, col):
        raise NotImplementedError()

    def _set_intXint(self, row, col, x):
        raise NotImplementedError()

    def _set_arrayXarray(self, row, col, x):
        raise NotImplementedError()

    def _set_arrayXarray_sparse(self, row, col, x):
        """Assign sparse `x` at ``(row, col)``.

        The default densifies `x`, broadcasts it against `row`, and
        delegates to `_set_arrayXarray`. Subclasses may override this
        with a sparse-aware implementation.
        """
        # Fall back to densifying x
        x = np.asarray(x.toarray(), dtype=self.dtype)
        x, _ = _broadcast_arrays(x, row)
        self._set_arrayXarray(row, col, x)

243 

244 

def _unpack_index(index):
    """ Parse index. Always return a tuple of the form (row, col).
    Valid type for row/col is integer, slice, or array of integers.

    Boolean masks are converted to integer positions:
    - A 2-D boolean ndarray or sparse matrix used as the whole key becomes
      ``index.nonzero()``.
    - A 1-D boolean row or column component becomes the positions of its
      True entries.

    Raises IndexError for a tuple with more than two components, for sparse
    matrices used as a row or column component, and for boolean masks with
    more than two dimensions.
    """
    # First, check if indexing with single boolean matrix.
    from .base import spmatrix, isspmatrix
    if (isinstance(index, (spmatrix, np.ndarray)) and
            index.ndim == 2 and index.dtype.kind == 'b'):
        return index.nonzero()

    # Parse any ellipses.
    index = _check_ellipsis(index)

    # Next, parse the tuple or object
    if isinstance(index, tuple):
        if len(index) == 2:
            row, col = index
        elif len(index) == 1:
            row, col = index[0], slice(None)
        else:
            raise IndexError('invalid number of indices')
    else:
        idx = _compatible_boolean_index(index)
        if idx is None:
            row, col = index, slice(None)
        elif idx.ndim < 2:
            return _boolean_index_to_array(idx), slice(None)
        elif idx.ndim == 2:
            return idx.nonzero()
        else:
            # A mask with more than two dimensions cannot address a matrix.
            # Without this branch `row`/`col` would be left unbound below.
            raise IndexError('invalid index shape')
    # Next, check for validity and transform the index as needed.
    if isspmatrix(row) or isspmatrix(col):
        # Supporting sparse boolean indexing with both row and col does
        # not work because spmatrix.ndim is always 2.
        raise IndexError(
            'Indexing with sparse matrices is not supported '
            'except boolean indexing where matrix and index '
            'are equal shapes.')
    bool_row = _compatible_boolean_index(row)
    bool_col = _compatible_boolean_index(col)
    if bool_row is not None:
        row = _boolean_index_to_array(bool_row)
    if bool_col is not None:
        col = _boolean_index_to_array(bool_col)
    return row, col

289 

290 

291def _check_ellipsis(index): 

292 """Process indices with Ellipsis. Returns modified index.""" 

293 if index is Ellipsis: 

294 return (slice(None), slice(None)) 

295 

296 if not isinstance(index, tuple): 

297 return index 

298 

299 # TODO: Deprecate this multiple-ellipsis handling, 

300 # as numpy no longer supports it. 

301 

302 # Find first ellipsis. 

303 for j, v in enumerate(index): 

304 if v is Ellipsis: 

305 first_ellipsis = j 

306 break 

307 else: 

308 return index 

309 

310 # Try to expand it using shortcuts for common cases 

311 if len(index) == 1: 

312 return (slice(None), slice(None)) 

313 if len(index) == 2: 

314 if first_ellipsis == 0: 

315 if index[1] is Ellipsis: 

316 return (slice(None), slice(None)) 

317 return (slice(None), index[1]) 

318 return (index[0], slice(None)) 

319 

320 # Expand it using a general-purpose algorithm 

321 tail = [] 

322 for v in index[first_ellipsis+1:]: 

323 if v is not Ellipsis: 

324 tail.append(v) 

325 nd = first_ellipsis + len(tail) 

326 nslice = max(0, 2 - nd) 

327 return index[:first_ellipsis] + (slice(None),)*nslice + tuple(tail) 

328 

329 

330def _maybe_bool_ndarray(idx): 

331 """Returns a compatible array if elements are boolean. 

332 """ 

333 idx = np.asanyarray(idx) 

334 if idx.dtype.kind == 'b': 

335 return idx 

336 return None 

337 

338 

339def _first_element_bool(idx, max_dim=2): 

340 """Returns True if first element of the incompatible 

341 array type is boolean. 

342 """ 

343 if max_dim < 1: 

344 return None 

345 try: 

346 first = next(iter(idx), None) 

347 except TypeError: 

348 return None 

349 if isinstance(first, bool): 

350 return True 

351 return _first_element_bool(first, max_dim-1) 

352 

353 

def _compatible_boolean_index(idx):
    """Return `idx` as a boolean array if it can serve as a boolean mask.

    Objects that expose ``ndim`` are treated as array-like. Any other object
    qualifies only if the first element of its (nested) sequence is a
    boolean. Returns None when `idx` is not a boolean index.
    """
    # ``ndim`` marks an array-like; otherwise peek at the leading element.
    is_candidate = hasattr(idx, 'ndim') or _first_element_bool(idx)
    if not is_candidate:
        return None
    return _maybe_bool_ndarray(idx)

362 

363 

364def _boolean_index_to_array(idx): 

365 if idx.ndim > 1: 

366 raise IndexError('invalid index shape') 

367 return np.where(idx)[0]