import bz2
from collections import Counter
from contextlib import contextmanager
from datetime import datetime
from functools import wraps
import gzip
import os
from shutil import rmtree
import string
import tempfile
from typing import Any, List, Optional, Union, cast
import warnings
import zipfile

import numpy as np
from numpy.random import rand, randn

from pandas._config.localization import (  # noqa:F401
    can_set_locale,
    get_locales,
    set_locale,
)

import pandas._libs.testing as _testing
from pandas._typing import FilePathOrBuffer, FrameOrSeries
from pandas.compat import _get_lzma_file, _import_lzma

from pandas.core.dtypes.common import (
    is_bool,
    is_categorical_dtype,
    is_datetime64_dtype,
    is_datetime64tz_dtype,
    is_extension_array_dtype,
    is_interval_dtype,
    is_list_like,
    is_number,
    is_period_dtype,
    is_sequence,
    is_timedelta64_dtype,
    needs_i8_conversion,
)
from pandas.core.dtypes.missing import array_equivalent

import pandas as pd
from pandas import (
    Categorical,
    CategoricalIndex,
    DataFrame,
    DatetimeIndex,
    Index,
    IntervalIndex,
    MultiIndex,
    RangeIndex,
    Series,
    bdate_range,
)
from pandas.core.algorithms import take_1d
from pandas.core.arrays import (
    DatetimeArray,
    ExtensionArray,
    IntervalArray,
    PeriodArray,
    TimedeltaArray,
    period_array,
)

from pandas.io.common import urlopen
from pandas.io.formats.printing import pprint_thing

lzma = _import_lzma()

N = 30
K = 4
_RAISE_NETWORK_ERROR_DEFAULT = False

# set testing_mode
_testing_mode_warnings = (DeprecationWarning, ResourceWarning)


def set_testing_mode():
    # set the testing mode filters
    testing_mode = os.environ.get("PANDAS_TESTING_MODE", "None")
    if "deprecate" in testing_mode:
        warnings.simplefilter("always", _testing_mode_warnings)


def reset_testing_mode():
    # reset the testing mode filters
    testing_mode = os.environ.get("PANDAS_TESTING_MODE", "None")
    if "deprecate" in testing_mode:
        warnings.simplefilter("ignore", _testing_mode_warnings)


set_testing_mode()


def reset_display_options():
    """
    Reset the display options for printing and representing objects.
    """
    pd.reset_option("^display.", silent=True)


def round_trip_pickle(
    obj: Any, path: Optional[FilePathOrBuffer] = None
) -> FrameOrSeries:
    """
    Pickle an object and then read it again.

    Parameters
    ----------
    obj : any object
        The object to pickle and then re-read.
    path : str, path object or file-like object, default None
        The path where the pickled object is written and then read.

    Returns
    -------
    pandas object
        The original object that was pickled and then re-read.
    """
    _path = path
    if _path is None:
        _path = f"__{rands(10)}__.pickle"
    with ensure_clean(_path) as temp_path:
        pd.to_pickle(obj, temp_path)
        return pd.read_pickle(temp_path)

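# Example (illustrative sketch, not part of the original module): round-trip
# a DataFrame through ``round_trip_pickle`` and check the result.
#
#     df = pd.DataFrame({"a": [1, 2, 3]})
#     result = round_trip_pickle(df)
#     assert_frame_equal(df, result)
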

def round_trip_pathlib(writer, reader, path: Optional[str] = None):
    """
    Write an object to file specified by a pathlib.Path and read it back.

    Parameters
    ----------
    writer : callable bound to pandas object
        IO writing function (e.g. DataFrame.to_csv)
    reader : callable
        IO reading function (e.g. pd.read_csv)
    path : str, default None
        The path where the object is written and then read.

    Returns
    -------
    pandas object
        The original object that was serialized and then re-read.
    """
    import pytest

    Path = pytest.importorskip("pathlib").Path
    if path is None:
        path = "___pathlib___"
    with ensure_clean(path) as path:
        writer(Path(path))
        obj = reader(Path(path))
    return obj


def round_trip_localpath(writer, reader, path: Optional[str] = None):
    """
    Write an object to file specified by a py.path LocalPath and read it back.

    Parameters
    ----------
    writer : callable bound to pandas object
        IO writing function (e.g. DataFrame.to_csv)
    reader : callable
        IO reading function (e.g. pd.read_csv)
    path : str, default None
        The path where the object is written and then read.

    Returns
    -------
    pandas object
        The original object that was serialized and then re-read.
    """
    import pytest

    LocalPath = pytest.importorskip("py.path").local
    if path is None:
        path = "___localpath___"
    with ensure_clean(path) as path:
        writer(LocalPath(path))
        obj = reader(LocalPath(path))
    return obj


@contextmanager
def decompress_file(path, compression):
    """
    Open a compressed file and return a file object.

    Parameters
    ----------
    path : str
        The path where the file is read from.

    compression : {'gzip', 'bz2', 'zip', 'xz', None}
        Name of the decompression to use.

    Returns
    -------
    file object
    """
    if compression is None:
        f = open(path, "rb")
    elif compression == "gzip":
        f = gzip.open(path, "rb")
    elif compression == "bz2":
        f = bz2.BZ2File(path, "rb")
    elif compression == "xz":
        f = _get_lzma_file(lzma)(path, "rb")
    elif compression == "zip":
        zip_file = zipfile.ZipFile(path)
        zip_names = zip_file.namelist()
        if len(zip_names) == 1:
            f = zip_file.open(zip_names.pop())
        else:
            raise ValueError(f"ZIP file {path} error. Only one file per ZIP.")
    else:
        raise ValueError(f"Unrecognized compression type: {compression}")

    try:
        yield f
    finally:
        f.close()
        if compression == "zip":
            zip_file.close()


def write_to_compressed(compression, path, data, dest="test"):
    """
    Write data to a compressed file.

    Parameters
    ----------
    compression : {'gzip', 'bz2', 'zip', 'xz'}
        The compression type to use.
    path : str
        The file path to write the data.
    data : str
        The data to write.
    dest : str, default "test"
        The destination file (for ZIP only).

    Raises
    ------
    ValueError : An invalid compression value was passed in.
    """
    if compression == "zip":
        import zipfile

        compress_method = zipfile.ZipFile
    elif compression == "gzip":
        import gzip

        compress_method = gzip.GzipFile
    elif compression == "bz2":
        import bz2

        compress_method = bz2.BZ2File
    elif compression == "xz":
        compress_method = _get_lzma_file(lzma)
    else:
        raise ValueError(f"Unrecognized compression type: {compression}")

    if compression == "zip":
        mode = "w"
        args = (dest, data)
        method = "writestr"
    else:
        mode = "wb"
        args = (data,)
        method = "write"

    with compress_method(path, mode=mode) as f:
        getattr(f, method)(*args)

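# Example (illustrative sketch, not part of the original module): write
# gzip-compressed bytes and read them back via ``decompress_file``.
#
#     with ensure_clean("data.gz") as tmp:
#         write_to_compressed("gzip", tmp, b"col1,col2\n1,2\n")
#         with decompress_file(tmp, compression="gzip") as fh:
#             assert fh.read() == b"col1,col2\n1,2\n"
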

def assert_almost_equal(
    left,
    right,
    check_dtype: Union[bool, str] = "equiv",
    check_less_precise: Union[bool, int] = False,
    **kwargs,
):
    """
    Check that the left and right objects are approximately equal.

    By approximately equal, we refer to objects that are numbers or that
    contain numbers which may be equivalent to specific levels of precision.

    Parameters
    ----------
    left : object
    right : object
    check_dtype : bool or {'equiv'}, default 'equiv'
        Check dtype if both a and b are the same type. If 'equiv' is passed in,
        then `RangeIndex` and `Int64Index` are also considered equivalent
        when doing type checking.
    check_less_precise : bool or int, default False
        Specify comparison precision. 5 digits (False) or 3 digits (True)
        after decimal points are compared. If int, then specify the number
        of digits to compare.

        When comparing two numbers, if the first number has magnitude less
        than 1e-5, we compare the two numbers directly and check whether
        they are equivalent within the specified precision. Otherwise, we
        compare the **ratio** of the second number to the first number and
        check whether it is equivalent to 1 within the specified precision.
    """
    if isinstance(left, pd.Index):
        assert_index_equal(
            left,
            right,
            check_exact=False,
            exact=check_dtype,
            check_less_precise=check_less_precise,
            **kwargs,
        )

    elif isinstance(left, pd.Series):
        assert_series_equal(
            left,
            right,
            check_exact=False,
            check_dtype=check_dtype,
            check_less_precise=check_less_precise,
            **kwargs,
        )

    elif isinstance(left, pd.DataFrame):
        assert_frame_equal(
            left,
            right,
            check_exact=False,
            check_dtype=check_dtype,
            check_less_precise=check_less_precise,
            **kwargs,
        )

    else:
        # Other sequences.
        if check_dtype:
            if is_number(left) and is_number(right):
                # Do not compare numeric classes, like np.float64 and float.
                pass
            elif is_bool(left) and is_bool(right):
                # Do not compare bool classes, like np.bool_ and bool.
                pass
            else:
                if isinstance(left, np.ndarray) or isinstance(right, np.ndarray):
                    obj = "numpy array"
                else:
                    obj = "Input"
                assert_class_equal(left, right, obj=obj)
        _testing.assert_almost_equal(
            left,
            right,
            check_dtype=check_dtype,
            check_less_precise=check_less_precise,
            **kwargs,
        )

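# Example (illustrative sketch): values whose ratio is within the default
# five-digit precision compare as approximately equal.
#
#     assert_almost_equal(0.300000001, 0.3)                                # passes
#     assert_almost_equal(np.array([1.0, 2.0]), np.array([1.0, 2.000001]))
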

def _check_isinstance(left, right, cls):
    """
    Helper method for our assert_* methods that ensures that
    the two objects being compared have the right type before
    proceeding with the comparison.

    Parameters
    ----------
    left : The first object being compared.
    right : The second object being compared.
    cls : The class type to check against.

    Raises
    ------
    AssertionError : Either `left` or `right` is not an instance of `cls`.
    """
    cls_name = cls.__name__

    if not isinstance(left, cls):
        raise AssertionError(
            f"{cls_name} Expected type {cls}, found {type(left)} instead"
        )
    if not isinstance(right, cls):
        raise AssertionError(
            f"{cls_name} Expected type {cls}, found {type(right)} instead"
        )


def assert_dict_equal(left, right, compare_keys: bool = True):

    _check_isinstance(left, right, dict)
    _testing.assert_dict_equal(left, right, compare_keys=compare_keys)


def randbool(size=(), p: float = 0.5):
    return rand(*size) <= p


RANDS_CHARS = np.array(list(string.ascii_letters + string.digits), dtype=(np.str_, 1))
RANDU_CHARS = np.array(
    list("".join(map(chr, range(1488, 1488 + 26))) + string.digits),
    dtype=(np.unicode_, 1),
)


def rands_array(nchars, size, dtype="O"):
    """
    Generate an array of random strings drawn from ASCII letters and digits.
    """
    retval = (
        np.random.choice(RANDS_CHARS, size=nchars * np.prod(size))
        .view((np.str_, nchars))
        .reshape(size)
    )
    if dtype is None:
        return retval
    else:
        return retval.astype(dtype)


def randu_array(nchars, size, dtype="O"):
    """
    Generate an array of unicode strings.
    """
    retval = (
        np.random.choice(RANDU_CHARS, size=nchars * np.prod(size))
        .view((np.unicode_, nchars))
        .reshape(size)
    )
    if dtype is None:
        return retval
    else:
        return retval.astype(dtype)


def rands(nchars):
    """
    Generate one random string of ASCII letters and digits.

    See `rands_array` if you want to create an array of random strings.

    """
    return "".join(np.random.choice(RANDS_CHARS, nchars))


def randu(nchars):
    """
    Generate one random unicode string.

    See `randu_array` if you want to create an array of random unicode strings.

    """
    return "".join(np.random.choice(RANDU_CHARS, nchars))


def close(fignum=None):
    from matplotlib.pyplot import get_fignums, close as _close

    if fignum is None:
        for fignum in get_fignums():
            _close(fignum)
    else:
        _close(fignum)


# -----------------------------------------------------------------------------
# contextmanager to ensure the file cleanup


@contextmanager
def ensure_clean(filename=None, return_filelike=False):
    """
    Get a temporary path and agree to remove it on close.

    Parameters
    ----------
    filename : str (optional)
        If None, creates a temporary file which is then removed when out of
        scope. If passed, creates a temporary file with filename as ending.
    return_filelike : bool (default False)
        If True, returns a file-like which is *always* cleaned. Necessary for
        savefig and other functions which want to append extensions.
    """
    filename = filename or ""
    fd = None

    if return_filelike:
        f = tempfile.TemporaryFile(suffix=filename)
        try:
            yield f
        finally:
            f.close()
    else:
        # don't generate tempfile if using a path with directory specified
        if len(os.path.dirname(filename)):
            raise ValueError("Can't pass a qualified name to ensure_clean()")

        try:
            fd, filename = tempfile.mkstemp(suffix=filename)
        except UnicodeEncodeError:
            import pytest

            pytest.skip("no unicode file names on this system")

        try:
            yield filename
        finally:
            try:
                os.close(fd)
            except OSError:
                print(f"Couldn't close file descriptor: {fd} (file: {filename})")
            try:
                if os.path.exists(filename):
                    os.remove(filename)
            except OSError as e:
                print(f"Exception on removing file: {e}")

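# Example (illustrative sketch): the temporary file is removed when the
# ``with`` block exits, even if the body raises.
#
#     with ensure_clean("tmp.csv") as path:
#         DataFrame({"a": [1, 2]}).to_csv(path)
#         result = pd.read_csv(path, index_col=0)
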

@contextmanager
def ensure_clean_dir():
    """
    Get a temporary directory path and agree to remove it on close.

    Yields
    ------
    Temporary directory path
    """
    directory_name = tempfile.mkdtemp(suffix="")
    try:
        yield directory_name
    finally:
        try:
            rmtree(directory_name)
        except OSError:
            pass


@contextmanager
def ensure_safe_environment_variables():
    """
    Get a context manager to safely set environment variables.

    All changes will be undone on close, hence environment variables set
    within this contextmanager will neither persist nor change global state.
    """
    saved_environ = dict(os.environ)
    try:
        yield
    finally:
        os.environ.clear()
        os.environ.update(saved_environ)

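# Example (illustrative sketch): mutate the environment inside the block;
# the snapshot taken on entry is restored on exit.
#
#     with ensure_safe_environment_variables():
#         os.environ["PANDAS_TESTING_MODE"] = "deprecate"
#         ...  # code that depends on the variable
#     # os.environ is back to its previous state here
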

# -----------------------------------------------------------------------------
# Comparators


def equalContents(arr1, arr2) -> bool:
    """
    Checks if the set of unique elements of arr1 and arr2 are equivalent.
    """
    return frozenset(arr1) == frozenset(arr2)


def assert_index_equal(
    left: Index,
    right: Index,
    exact: Union[bool, str] = "equiv",
    check_names: bool = True,
    check_less_precise: Union[bool, int] = False,
    check_exact: bool = True,
    check_categorical: bool = True,
    obj: str = "Index",
) -> None:
    """
    Check that left and right Index are equal.

    Parameters
    ----------
    left : Index
    right : Index
    exact : bool or {'equiv'}, default 'equiv'
        Whether to check the Index class, dtype and inferred_type
        are identical. If 'equiv', then RangeIndex can be substituted for
        Int64Index as well.
    check_names : bool, default True
        Whether to check the names attribute.
    check_less_precise : bool or int, default False
        Specify comparison precision. Only used when check_exact is False.
        5 digits (False) or 3 digits (True) after decimal points are compared.
        If int, then specify the digits to compare.
    check_exact : bool, default True
        Whether to compare number exactly.
    check_categorical : bool, default True
        Whether to compare internal Categorical exactly.
    obj : str, default 'Index'
        Specify object name being compared, internally used to show appropriate
        assertion message.
    """
    __tracebackhide__ = True

    def _check_types(l, r, obj="Index"):
        if exact:
            assert_class_equal(l, r, exact=exact, obj=obj)

            # Skip exact dtype checking when `check_categorical` is False
            if check_categorical:
                assert_attr_equal("dtype", l, r, obj=obj)

            # allow string-like to have different inferred_types
            if l.inferred_type in ("string", "unicode"):
                assert r.inferred_type in ("string", "unicode")
            else:
                assert_attr_equal("inferred_type", l, r, obj=obj)

    def _get_ilevel_values(index, level):
        # accept level number only
        unique = index.levels[level]
        level_codes = index.codes[level]
        filled = take_1d(unique._values, level_codes, fill_value=unique._na_value)
        values = unique._shallow_copy(filled, name=index.names[level])
        return values

    # instance validation
    _check_isinstance(left, right, Index)

    # class / dtype comparison
    _check_types(left, right, obj=obj)

    # level comparison
    if left.nlevels != right.nlevels:
        msg1 = f"{obj} levels are different"
        msg2 = f"{left.nlevels}, {left}"
        msg3 = f"{right.nlevels}, {right}"
        raise_assert_detail(obj, msg1, msg2, msg3)

    # length comparison
    if len(left) != len(right):
        msg1 = f"{obj} length are different"
        msg2 = f"{len(left)}, {left}"
        msg3 = f"{len(right)}, {right}"
        raise_assert_detail(obj, msg1, msg2, msg3)

    # MultiIndex special comparison for more user-friendly error messages
    if left.nlevels > 1:
        left = cast(MultiIndex, left)
        right = cast(MultiIndex, right)

        for level in range(left.nlevels):
            # cannot use get_level_values here because it can change dtype
            llevel = _get_ilevel_values(left, level)
            rlevel = _get_ilevel_values(right, level)

            lobj = f"MultiIndex level [{level}]"
            assert_index_equal(
                llevel,
                rlevel,
                exact=exact,
                check_names=check_names,
                check_less_precise=check_less_precise,
                check_exact=check_exact,
                obj=lobj,
            )
            # get_level_values may change dtype
            _check_types(left.levels[level], right.levels[level], obj=obj)

    # skip exact index checking when `check_categorical` is False
    if check_exact and check_categorical:
        if not left.equals(right):
            diff = np.sum((left.values != right.values).astype(int)) * 100.0 / len(left)
            msg = f"{obj} values are different ({np.round(diff, 5)} %)"
            raise_assert_detail(obj, msg, left, right)
    else:
        _testing.assert_almost_equal(
            left.values,
            right.values,
            check_less_precise=check_less_precise,
            check_dtype=exact,
            obj=obj,
            lobj=left,
            robj=right,
        )

    # metadata comparison
    if check_names:
        assert_attr_equal("names", left, right, obj=obj)
    if isinstance(left, pd.PeriodIndex) or isinstance(right, pd.PeriodIndex):
        assert_attr_equal("freq", left, right, obj=obj)
    if isinstance(left, pd.IntervalIndex) or isinstance(right, pd.IntervalIndex):
        assert_interval_array_equal(left.values, right.values)

    if check_categorical:
        if is_categorical_dtype(left) or is_categorical_dtype(right):
            assert_categorical_equal(left.values, right.values, obj=f"{obj} category")

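# Example (illustrative sketch): with the default exact="equiv", a RangeIndex
# and an equivalent Int64Index compare equal; exact=True requires identical
# classes and raises.
#
#     assert_index_equal(RangeIndex(3), Index([0, 1, 2]))              # passes
#     assert_index_equal(RangeIndex(3), Index([0, 1, 2]), exact=True)  # raises
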

def assert_class_equal(left, right, exact: Union[bool, str] = True, obj="Input"):
    """
    Checks classes are equal.
    """
    __tracebackhide__ = True

    def repr_class(x):
        if isinstance(x, Index):
            # return Index as it is to include values in the error message
            return x

        try:
            return type(x).__name__
        except AttributeError:
            return repr(type(x))

    if exact == "equiv":
        if type(left) != type(right):
            # allow equivalence of Int64Index/RangeIndex
            types = {type(left).__name__, type(right).__name__}
            if len(types - {"Int64Index", "RangeIndex"}):
                msg = f"{obj} classes are not equivalent"
                raise_assert_detail(obj, msg, repr_class(left), repr_class(right))
    elif exact:
        if type(left) != type(right):
            msg = f"{obj} classes are different"
            raise_assert_detail(obj, msg, repr_class(left), repr_class(right))


def assert_attr_equal(attr, left, right, obj="Attributes"):
    """checks attributes are equal. Both objects must have attribute.

    Parameters
    ----------
    attr : str
        Attribute name being compared.
    left : object
    right : object
    obj : str, default 'Attributes'
        Specify object name being compared, internally used to show appropriate
        assertion message
    """
    __tracebackhide__ = True

    left_attr = getattr(left, attr)
    right_attr = getattr(right, attr)

    if left_attr is right_attr:
        return True
    elif (
        is_number(left_attr)
        and np.isnan(left_attr)
        and is_number(right_attr)
        and np.isnan(right_attr)
    ):
        # np.nan
        return True

    try:
        result = left_attr == right_attr
    except TypeError:
        # datetimetz on rhs may raise TypeError
        result = False
    if not isinstance(result, bool):
        result = result.all()

    if result:
        return True
    else:
        msg = f'Attribute "{attr}" are different'
        raise_assert_detail(obj, msg, left_attr, right_attr)


def assert_is_valid_plot_return_object(objs):
    import matplotlib.pyplot as plt

    if isinstance(objs, (pd.Series, np.ndarray)):
        for el in objs.ravel():
            msg = (
                "one of 'objs' is not a matplotlib Axes instance, "
                f"type encountered {repr(type(el).__name__)}"
            )
            assert isinstance(el, (plt.Axes, dict)), msg
    else:
        msg = (
            "objs is neither an ndarray of Artist instances nor a single "
            "Artist instance, tuple, or dict, 'objs' is a "
            f"{repr(type(objs).__name__)}"
        )
        assert isinstance(objs, (plt.Artist, tuple, dict)), msg


def isiterable(obj):
    return hasattr(obj, "__iter__")


def assert_is_sorted(seq):
    """Assert that the sequence is sorted."""
    if isinstance(seq, (Index, Series)):
        seq = seq.values
    # sorting does not change precisions
    assert_numpy_array_equal(seq, np.sort(np.array(seq)))


def assert_categorical_equal(
    left, right, check_dtype=True, check_category_order=True, obj="Categorical"
):
    """Test that Categoricals are equivalent.

    Parameters
    ----------
    left : Categorical
    right : Categorical
    check_dtype : bool, default True
        Check that integer dtype of the codes are the same
    check_category_order : bool, default True
        Whether the order of the categories should be compared, which
        implies identical integer codes. If False, only the resulting
        values are compared. The ordered attribute is
        checked regardless.
    obj : str, default 'Categorical'
        Specify object name being compared, internally used to show appropriate
        assertion message
    """
    _check_isinstance(left, right, Categorical)

    if check_category_order:
        assert_index_equal(left.categories, right.categories, obj=f"{obj}.categories")
        assert_numpy_array_equal(
            left.codes, right.codes, check_dtype=check_dtype, obj=f"{obj}.codes",
        )
    else:
        assert_index_equal(
            left.categories.sort_values(),
            right.categories.sort_values(),
            obj=f"{obj}.categories",
        )
        assert_index_equal(
            left.categories.take(left.codes),
            right.categories.take(right.codes),
            obj=f"{obj}.values",
        )

    assert_attr_equal("ordered", left, right, obj=obj)


def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray"):
    """Test that two IntervalArrays are equivalent.

    Parameters
    ----------
    left, right : IntervalArray
        The IntervalArrays to compare.
    exact : bool or {'equiv'}, default 'equiv'
        Whether to check the Index class, dtype and inferred_type
        are identical. If 'equiv', then RangeIndex can be substituted for
        Int64Index as well.
    obj : str, default 'IntervalArray'
        Specify object name being compared, internally used to show appropriate
        assertion message
    """
    _check_isinstance(left, right, IntervalArray)

    assert_index_equal(left.left, right.left, exact=exact, obj=f"{obj}.left")
    assert_index_equal(left.right, right.right, exact=exact, obj=f"{obj}.right")
    assert_attr_equal("closed", left, right, obj=obj)


def assert_period_array_equal(left, right, obj="PeriodArray"):
    _check_isinstance(left, right, PeriodArray)

    assert_numpy_array_equal(left._data, right._data, obj=f"{obj}.values")
    assert_attr_equal("freq", left, right, obj=obj)


def assert_datetime_array_equal(left, right, obj="DatetimeArray"):
    __tracebackhide__ = True
    _check_isinstance(left, right, DatetimeArray)

    assert_numpy_array_equal(left._data, right._data, obj=f"{obj}._data")
    assert_attr_equal("freq", left, right, obj=obj)
    assert_attr_equal("tz", left, right, obj=obj)


def assert_timedelta_array_equal(left, right, obj="TimedeltaArray"):
    __tracebackhide__ = True
    _check_isinstance(left, right, TimedeltaArray)
    assert_numpy_array_equal(left._data, right._data, obj=f"{obj}._data")
    assert_attr_equal("freq", left, right, obj=obj)


def raise_assert_detail(obj, message, left, right, diff=None):
    __tracebackhide__ = True

    if isinstance(left, np.ndarray):
        left = pprint_thing(left)
    elif is_categorical_dtype(left):
        left = repr(left)

    if isinstance(right, np.ndarray):
        right = pprint_thing(right)
    elif is_categorical_dtype(right):
        right = repr(right)

    msg = f"""{obj} are different

{message}
[left]: {left}
[right]: {right}"""

    if diff is not None:
        msg += f"\n[diff]: {diff}"

    raise AssertionError(msg)


def assert_numpy_array_equal(
    left,
    right,
    strict_nan=False,
    check_dtype=True,
    err_msg=None,
    check_same=None,
    obj="numpy array",
):
    """
    Check that 'np.ndarray' is equivalent.

    Parameters
    ----------
    left, right : numpy.ndarray or iterable
        The two arrays to be compared.
    strict_nan : bool, default False
        If True, consider NaN and None to be different.
    check_dtype : bool, default True
        Check dtype if both a and b are np.ndarray.
    err_msg : str, default None
        If provided, used as assertion message.
    check_same : None|'copy'|'same', default None
        Ensure left and right refer/do not refer to the same memory area.
    obj : str, default 'numpy array'
        Specify object name being compared, internally used to show appropriate
        assertion message.
    """
    __tracebackhide__ = True

    # instance validation
    # Show a detailed error message when classes are different
    assert_class_equal(left, right, obj=obj)
    # both classes must be an np.ndarray
    _check_isinstance(left, right, np.ndarray)

    def _get_base(obj):
        return obj.base if getattr(obj, "base", None) is not None else obj

    left_base = _get_base(left)
    right_base = _get_base(right)

    if check_same == "same":
        if left_base is not right_base:
            raise AssertionError(f"{repr(left_base)} is not {repr(right_base)}")
    elif check_same == "copy":
        if left_base is right_base:
            raise AssertionError(f"{repr(left_base)} is {repr(right_base)}")

    def _raise(left, right, err_msg):
        if err_msg is None:
            if left.shape != right.shape:
                raise_assert_detail(
                    obj, f"{obj} shapes are different", left.shape, right.shape,
                )

            diff = 0
            for l, r in zip(left, right):
                # count up differences
                if not array_equivalent(l, r, strict_nan=strict_nan):
                    diff += 1

            diff = diff * 100.0 / left.size
            msg = f"{obj} values are different ({np.round(diff, 5)} %)"
            raise_assert_detail(obj, msg, left, right)

        raise AssertionError(err_msg)

    # compare shape and values
    if not array_equivalent(left, right, strict_nan=strict_nan):
        _raise(left, right, err_msg)

    if check_dtype:
        if isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
            assert_attr_equal("dtype", left, right, obj=obj)

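# Example (illustrative sketch): identical arrays pass; a dtype mismatch
# raises unless check_dtype=False.
#
#     assert_numpy_array_equal(np.array([1, 2]), np.array([1, 2]))
#     assert_numpy_array_equal(
#         np.array([1, 2]), np.array([1.0, 2.0]), check_dtype=False
#     )
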

def assert_extension_array_equal(
    left, right, check_dtype=True, check_less_precise=False, check_exact=False
):
    """Check that left and right ExtensionArrays are equal.

    Parameters
    ----------
    left, right : ExtensionArray
        The two arrays to compare
    check_dtype : bool, default True
        Whether to check if the ExtensionArray dtypes are identical.
    check_less_precise : bool or int, default False
        Specify comparison precision. Only used when check_exact is False.
        5 digits (False) or 3 digits (True) after decimal points are compared.
        If int, then specify the digits to compare.
    check_exact : bool, default False
        Whether to compare number exactly.

    Notes
    -----
    Missing values are checked separately from valid values.
    A mask of missing values is computed for each and checked to match.
    The remaining all-valid values are cast to object dtype and checked.
    """
    assert isinstance(left, ExtensionArray), "left is not an ExtensionArray"
    assert isinstance(right, ExtensionArray), "right is not an ExtensionArray"
    if check_dtype:
        assert_attr_equal("dtype", left, right, obj="ExtensionArray")

    if hasattr(left, "asi8") and type(right) == type(left):
        # Avoid slow object-dtype comparisons
        assert_numpy_array_equal(left.asi8, right.asi8)
        return

    left_na = np.asarray(left.isna())
    right_na = np.asarray(right.isna())
    assert_numpy_array_equal(left_na, right_na, obj="ExtensionArray NA mask")

    left_valid = np.asarray(left[~left_na].astype(object))
    right_valid = np.asarray(right[~right_na].astype(object))
    if check_exact:
        assert_numpy_array_equal(left_valid, right_valid, obj="ExtensionArray")
    else:
        _testing.assert_almost_equal(
            left_valid,
            right_valid,
            check_dtype=check_dtype,
            check_less_precise=check_less_precise,
            obj="ExtensionArray",
        )


# This could be refactored to use the NDFrame.equals method
def assert_series_equal(
    left,
    right,
    check_dtype=True,
    check_index_type="equiv",
    check_series_type=True,
    check_less_precise=False,
    check_names=True,
    check_exact=False,
    check_datetimelike_compat=False,
    check_categorical=True,
    check_category_order=True,
    obj="Series",
):
    """
    Check that left and right Series are equal.

    Parameters
    ----------
    left : Series
    right : Series
    check_dtype : bool, default True
        Whether to check the Series dtype is identical.
    check_index_type : bool or {'equiv'}, default 'equiv'
        Whether to check the Index class, dtype and inferred_type
        are identical.
    check_series_type : bool, default True
        Whether to check the Series class is identical.
    check_less_precise : bool or int, default False
        Specify comparison precision. Only used when check_exact is False.
        5 digits (False) or 3 digits (True) after decimal points are compared.
        If int, then specify the digits to compare.

        When comparing two numbers, if the first number has magnitude less
        than 1e-5, we compare the two numbers directly and check whether
        they are equivalent within the specified precision. Otherwise, we
        compare the **ratio** of the second number to the first number and
        check whether it is equivalent to 1 within the specified precision.
    check_names : bool, default True
        Whether to check the Series and Index names attribute.
    check_exact : bool, default False
        Whether to compare number exactly.
    check_datetimelike_compat : bool, default False
        Compare datetime-like which is comparable ignoring dtype.
    check_categorical : bool, default True
        Whether to compare internal Categorical exactly.
    check_category_order : bool, default True
        Whether to compare category order of internal Categoricals

        .. versionadded:: 1.0.2
    obj : str, default 'Series'
        Specify object name being compared, internally used to show appropriate
        assertion message.
    """
    __tracebackhide__ = True

    # instance validation
    _check_isinstance(left, right, Series)

    if check_series_type:
        # ToDo: There are some tests using rhs is sparse
        # lhs is dense. Should use assert_class_equal in future
        assert isinstance(left, type(right))
        # assert_class_equal(left, right, obj=obj)

    # length comparison
    if len(left) != len(right):
        msg1 = f"{len(left)}, {left.index}"
        msg2 = f"{len(right)}, {right.index}"
        raise_assert_detail(obj, "Series length are different", msg1, msg2)

    # index comparison
    assert_index_equal(
        left.index,
        right.index,
        exact=check_index_type,
        check_names=check_names,
        check_less_precise=check_less_precise,
        check_exact=check_exact,
        check_categorical=check_categorical,
        obj=f"{obj}.index",
    )

    if check_dtype:
        # We want to skip exact dtype checking when `check_categorical`
        # is False. We'll still raise if only one is a `Categorical`,
        # regardless of `check_categorical`
        if (
            is_categorical_dtype(left)
            and is_categorical_dtype(right)
            and not check_categorical
        ):
            pass
        else:
            assert_attr_equal("dtype", left, right, obj=f"Attributes of {obj}")

    if check_exact:
        assert_numpy_array_equal(
            left._internal_get_values(),
            right._internal_get_values(),
            check_dtype=check_dtype,
            obj=str(obj),
        )
    elif check_datetimelike_compat:
        # we want to check only if we have compat dtypes
        # e.g. integer and M|m are NOT compat, but we can simply check
        # the values in that case
        if needs_i8_conversion(left) or needs_i8_conversion(right):

            # datetimelike may have different objects (e.g. datetime.datetime
            # vs Timestamp) but will compare equal
            if not Index(left.values).equals(Index(right.values)):
                msg = (
                    f"[datetimelike_compat=True] {left.values} "
                    f"is not equal to {right.values}."
                )
                raise AssertionError(msg)
        else:
            assert_numpy_array_equal(
                left._internal_get_values(),
                right._internal_get_values(),
                check_dtype=check_dtype,
            )
    elif is_interval_dtype(left) or is_interval_dtype(right):
        assert_interval_array_equal(left.array, right.array)
    elif is_extension_array_dtype(left.dtype) and is_datetime64tz_dtype(left.dtype):
        # .values is an ndarray, but ._values is the ExtensionArray.
        # TODO: Use .array
        assert is_extension_array_dtype(right.dtype)
        assert_extension_array_equal(left._values, right._values)
    elif (
        is_extension_array_dtype(left)
        and not is_categorical_dtype(left)
        and is_extension_array_dtype(right)
        and not is_categorical_dtype(right)
    ):
        assert_extension_array_equal(left.array, right.array)
    else:
        _testing.assert_almost_equal(
            left._internal_get_values(),
            right._internal_get_values(),
            check_less_precise=check_less_precise,
            check_dtype=check_dtype,
            obj=str(obj),
        )

    # metadata comparison
    if check_names:
        assert_attr_equal("name", left, right, obj=obj)

    if check_categorical:
        if is_categorical_dtype(left) or is_categorical_dtype(right):
            assert_categorical_equal(
                left.values,
                right.values,
                obj=f"{obj} category",
                check_category_order=check_category_order,
            )

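# Example (illustrative sketch): dtype differences fail by default but can
# be ignored with check_dtype=False.
#
#     s1 = Series([1, 2, 3])
#     s2 = Series([1.0, 2.0, 3.0])
#     assert_series_equal(s1, s2, check_dtype=False)   # passes
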

# This could be refactored to use the NDFrame.equals method
def assert_frame_equal(
    left,
    right,
    check_dtype=True,
    check_index_type="equiv",
    check_column_type="equiv",
    check_frame_type=True,
    check_less_precise=False,
    check_names=True,
    by_blocks=False,
    check_exact=False,
    check_datetimelike_compat=False,
    check_categorical=True,
    check_like=False,
    obj="DataFrame",
):
    """
    Check that left and right DataFrame are equal.

    This function is intended to compare two DataFrames and output any
    differences. It is mostly intended for use in unit tests.
    Additional parameters allow varying the strictness of the
    equality checks performed.

    Parameters
    ----------
    left : DataFrame
        First DataFrame to compare.
    right : DataFrame
        Second DataFrame to compare.
    check_dtype : bool, default True
        Whether to check the DataFrame dtype is identical.
    check_index_type : bool or {'equiv'}, default 'equiv'
        Whether to check the Index class, dtype and inferred_type
        are identical.
    check_column_type : bool or {'equiv'}, default 'equiv'
        Whether to check the columns class, dtype and inferred_type
        are identical. Is passed as the ``exact`` argument of
        :func:`assert_index_equal`.
    check_frame_type : bool, default True
        Whether to check the DataFrame class is identical.
    check_less_precise : bool or int, default False
        Specify comparison precision. Only used when check_exact is False.
        5 digits (False) or 3 digits (True) after decimal points are compared.
        If int, then specify the digits to compare.

        When comparing two numbers, if the first number has magnitude less
        than 1e-5, we compare the two numbers directly and check whether
        they are equivalent within the specified precision. Otherwise, we
        compare the **ratio** of the second number to the first number and
        check whether it is equivalent to 1 within the specified precision.
    check_names : bool, default True
        Whether to check that the `names` attribute for both the `index`
        and `column` attributes of the DataFrame is identical.
    by_blocks : bool, default False
        Specify how to compare internal data. If False, compare by columns.
        If True, compare by blocks.
    check_exact : bool, default False
        Whether to compare number exactly.
    check_datetimelike_compat : bool, default False
        Compare datetime-like which is comparable ignoring dtype.
    check_categorical : bool, default True
        Whether to compare internal Categorical exactly.
    check_like : bool, default False
        If True, ignore the order of index & columns.
        Note: index labels must match their respective rows
        (same as in columns) - same labels must be with the same data.
    obj : str, default 'DataFrame'
        Specify object name being compared, internally used to show appropriate
        assertion message.

    See Also
    --------
    assert_series_equal : Equivalent method for asserting Series equality.
    DataFrame.equals : Check DataFrame equality.

    Examples
    --------
    This example shows comparing two DataFrames that are equal
    but with columns of differing dtypes.

    >>> from pandas._testing import assert_frame_equal
    >>> df1 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
    >>> df2 = pd.DataFrame({'a': [1, 2], 'b': [3.0, 4.0]})

    df1 equals itself.

    >>> assert_frame_equal(df1, df1)

    df1 differs from df2 as column 'b' is of a different type.

    >>> assert_frame_equal(df1, df2)
    Traceback (most recent call last):
    ...
    AssertionError: Attributes of DataFrame.iloc[:, 1] (column name="b") are different

    Attribute "dtype" are different
    [left]: int64
    [right]: float64

    Ignore differing dtypes in columns with check_dtype.

    >>> assert_frame_equal(df1, df2, check_dtype=False)
    """
    __tracebackhide__ = True

    # instance validation
    _check_isinstance(left, right, DataFrame)

    if check_frame_type:
        assert isinstance(left, type(right))
        # assert_class_equal(left, right, obj=obj)

    # shape comparison
    if left.shape != right.shape:
        raise_assert_detail(
            obj, f"{obj} shape mismatch", f"{repr(left.shape)}", f"{repr(right.shape)}",
        )

    if check_like:
        left, right = left.reindex_like(right), right

    # index comparison
    assert_index_equal(
        left.index,
        right.index,
        exact=check_index_type,
        check_names=check_names,
        check_less_precise=check_less_precise,
        check_exact=check_exact,
        check_categorical=check_categorical,
        obj=f"{obj}.index",
    )

    # column comparison
    assert_index_equal(
        left.columns,
        right.columns,
        exact=check_column_type,
        check_names=check_names,
        check_less_precise=check_less_precise,
        check_exact=check_exact,
        check_categorical=check_categorical,
        obj=f"{obj}.columns",
    )

    # compare by blocks
    if by_blocks:
        rblocks = right._to_dict_of_blocks()
        lblocks = left._to_dict_of_blocks()
        for dtype in list(set(list(lblocks.keys()) + list(rblocks.keys()))):
            assert dtype in lblocks
            assert dtype in rblocks
            assert_frame_equal(
                lblocks[dtype], rblocks[dtype], check_dtype=check_dtype, obj=obj
            )

    # compare by columns
    else:
        for i, col in enumerate(left.columns):
            assert col in right
            lcol = left.iloc[:, i]
            rcol = right.iloc[:, i]
            assert_series_equal(
                lcol,
                rcol,
                check_dtype=check_dtype,
                check_index_type=check_index_type,
                check_less_precise=check_less_precise,
                check_exact=check_exact,
                check_names=check_names,
                check_datetimelike_compat=check_datetimelike_compat,
                check_categorical=check_categorical,
                obj=f'{obj}.iloc[:, {i}] (column name="{col}")',
            )


def assert_equal(left, right, **kwargs):
    """
    Wrapper for tm.assert_*_equal to dispatch to the appropriate test function.

    Parameters
    ----------
    left, right : Index, Series, DataFrame, ExtensionArray, or np.ndarray
        The two items to be compared.
    **kwargs
        All keyword arguments are passed through to the underlying assert method.
    """
    __tracebackhide__ = True

    if isinstance(left, pd.Index):
        assert_index_equal(left, right, **kwargs)
    elif isinstance(left, pd.Series):
        assert_series_equal(left, right, **kwargs)
    elif isinstance(left, pd.DataFrame):
        assert_frame_equal(left, right, **kwargs)
    elif isinstance(left, IntervalArray):
        assert_interval_array_equal(left, right, **kwargs)
    elif isinstance(left, PeriodArray):
        assert_period_array_equal(left, right, **kwargs)
    elif isinstance(left, DatetimeArray):
        assert_datetime_array_equal(left, right, **kwargs)
    elif isinstance(left, TimedeltaArray):
        assert_timedelta_array_equal(left, right, **kwargs)
    elif isinstance(left, ExtensionArray):
        assert_extension_array_equal(left, right, **kwargs)
    elif isinstance(left, np.ndarray):
        assert_numpy_array_equal(left, right, **kwargs)
    elif isinstance(left, str):
        assert kwargs == {}
        assert left == right
    else:
        raise NotImplementedError(type(left))

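# Example (illustrative sketch): assert_equal dispatches on the type of
# ``left``, so one call covers several container types.
#
#     assert_equal(Index([1, 2]), Index([1, 2]))        # -> assert_index_equal
#     assert_equal(np.array([1, 2]), np.array([1, 2]))  # -> assert_numpy_array_equal
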

def box_expected(expected, box_cls, transpose=True):
    """
    Helper function to wrap the expected output of a test in a given box_class.

    Parameters
    ----------
    expected : np.ndarray, Index, Series
    box_cls : {Index, Series, DataFrame}

    Returns
    -------
    subclass of box_cls
    """
    if box_cls is pd.Index:
        expected = pd.Index(expected)
    elif box_cls is pd.Series:
        expected = pd.Series(expected)
    elif box_cls is pd.DataFrame:
        expected = pd.Series(expected).to_frame()
        if transpose:
            # for vector operations, we need a DataFrame to be a single-row,
            # not a single-column, in order to operate against non-DataFrame
            # vectors of the same length.
            expected = expected.T
    elif box_cls is PeriodArray:
        # the PeriodArray constructor is not as flexible as period_array
        expected = period_array(expected)
    elif box_cls is DatetimeArray:
        expected = DatetimeArray(expected)
    elif box_cls is TimedeltaArray:
        expected = TimedeltaArray(expected)
    elif box_cls is np.ndarray:
        expected = np.array(expected)
    elif box_cls is to_array:
        expected = to_array(expected)
    else:
        raise NotImplementedError(box_cls)
    return expected


def to_array(obj):
    # temporary implementation until we get pd.array in place
    if is_period_dtype(obj):
        return period_array(obj)
    elif is_datetime64_dtype(obj) or is_datetime64tz_dtype(obj):
        return DatetimeArray._from_sequence(obj)
    elif is_timedelta64_dtype(obj):
        return TimedeltaArray._from_sequence(obj)
    else:
        return np.array(obj)


# -----------------------------------------------------------------------------
# Sparse


def assert_sp_array_equal(
    left,
    right,
    check_dtype=True,
    check_kind=True,
    check_fill_value=True,
    consolidate_block_indices=False,
):
    """Check that the left and right SparseArray are equal.

    Parameters
    ----------
    left : SparseArray
    right : SparseArray
    check_dtype : bool, default True
        Whether to check the data dtype is identical.
    check_kind : bool, default True
        Whether to just check the kind of the sparse index for each column.
    check_fill_value : bool, default True
        Whether to check that left.fill_value matches right.fill_value
    consolidate_block_indices : bool, default False
        Whether to consolidate contiguous blocks for sparse arrays with
        a BlockIndex. Some operations, e.g. concat, will end up with
        block indices that could be consolidated. Setting this to true will
        create a new BlockIndex for that array, with consolidated
        block indices.
    """

    _check_isinstance(left, right, pd.arrays.SparseArray)

    assert_numpy_array_equal(left.sp_values, right.sp_values, check_dtype=check_dtype)

    # SparseIndex comparison
    assert isinstance(left.sp_index, pd._libs.sparse.SparseIndex)
    assert isinstance(right.sp_index, pd._libs.sparse.SparseIndex)

    if not check_kind:
        left_index = left.sp_index.to_block_index()
        right_index = right.sp_index.to_block_index()
    else:
        left_index = left.sp_index
        right_index = right.sp_index

    if consolidate_block_indices and left.kind == "block":
        # we'll probably remove this hack...
        left_index = left_index.to_int_index().to_block_index()
        right_index = right_index.to_int_index().to_block_index()

    if not left_index.equals(right_index):
        raise_assert_detail(
            "SparseArray.index", "index are not equal", left_index, right_index
        )
    else:
        # indices are equal, nothing more to check here
        pass

    if check_fill_value:
        assert_attr_equal("fill_value", left, right)
    if check_dtype:
        assert_attr_equal("dtype", left, right)
    assert_numpy_array_equal(left.to_dense(), right.to_dense(), check_dtype=check_dtype)


# -----------------------------------------------------------------------------
# Others


def assert_contains_all(iterable, dic):
    for k in iterable:
        assert k in dic, f"Did not contain item: {repr(k)}"


def assert_copy(iter1, iter2, **eql_kwargs):
    """
    iter1, iter2: iterables that produce elements
    comparable with assert_almost_equal

    Checks that the elements are equal, but not
    the same object. (Does not check that items
    in sequences are also not the same object)
    """
    for elem1, elem2 in zip(iter1, iter2):
        assert_almost_equal(elem1, elem2, **eql_kwargs)
        msg = (
            f"Expected object {repr(type(elem1))} and object {repr(type(elem2))} to be "
            "different objects, but they were the same object."
        )
        assert elem1 is not elem2, msg


def getCols(k):
    return string.ascii_uppercase[:k]


# make index
def makeStringIndex(k=10, name=None):
    return Index(rands_array(nchars=10, size=k), name=name)


def makeUnicodeIndex(k=10, name=None):
    return Index(randu_array(nchars=10, size=k), name=name)


def makeCategoricalIndex(k=10, n=3, name=None, **kwargs):
    """ make a length k CategoricalIndex with n categories """
    x = rands_array(nchars=4, size=n)
    return CategoricalIndex(
        Categorical.from_codes(np.arange(k) % n, categories=x), name=name, **kwargs
    )


def makeIntervalIndex(k=10, name=None, **kwargs):
    """ make a length k IntervalIndex """
    x = np.linspace(0, 100, num=(k + 1))
    return IntervalIndex.from_breaks(x, name=name, **kwargs)


def makeBoolIndex(k=10, name=None):
    if k == 1:
        return Index([True], name=name)
    elif k == 2:
        return Index([False, True], name=name)
    return Index([False, True] + [False] * (k - 2), name=name)


def makeIntIndex(k=10, name=None):
    return Index(list(range(k)), name=name)


def makeUIntIndex(k=10, name=None):
    return Index([2 ** 63 + i for i in range(k)], name=name)


def makeRangeIndex(k=10, name=None, **kwargs):
    return RangeIndex(0, k, 1, name=name, **kwargs)


def makeFloatIndex(k=10, name=None):
    values = sorted(np.random.random_sample(k)) - np.random.random_sample(1)
    return Index(values * (10 ** np.random.randint(0, 9)), name=name)


def makeDateIndex(k=10, freq="B", name=None, **kwargs):
    dt = datetime(2000, 1, 1)
    dr = bdate_range(dt, periods=k, freq=freq, name=name)
    return DatetimeIndex(dr, name=name, **kwargs)


def makeTimedeltaIndex(k=10, freq="D", name=None, **kwargs):
    return pd.timedelta_range(start="1 day", periods=k, freq=freq, name=name, **kwargs)


def makePeriodIndex(k=10, name=None, **kwargs):
    dt = datetime(2000, 1, 1)
    dr = pd.period_range(start=dt, periods=k, freq="B", name=name, **kwargs)
    return dr


def makeMultiIndex(k=10, names=None, **kwargs):
    return MultiIndex.from_product((("foo", "bar"), (1, 2)), names=names, **kwargs)

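# Example (illustrative sketch): the make*Index helpers build small indexes
# for tests.
#
#     idx = makeDateIndex(k=5)              # 5-element business-day DatetimeIndex
#     cat = makeCategoricalIndex(k=6, n=3)  # 6 entries drawn from 3 categories
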

_names = [
    "Alice",
    "Bob",
    "Charlie",
    "Dan",
    "Edith",
    "Frank",
    "George",
    "Hannah",
    "Ingrid",
    "Jerry",
    "Kevin",
    "Laura",
    "Michael",
    "Norbert",
    "Oliver",
    "Patricia",
    "Quinn",
    "Ray",
    "Sarah",
    "Tim",
    "Ursula",
    "Victor",
    "Wendy",
    "Xavier",
    "Yvonne",
    "Zelda",
]


def _make_timeseries(start="2000-01-01", end="2000-12-31", freq="1D", seed=None):
    """
    Make a DataFrame with a DatetimeIndex

    Parameters
    ----------
    start : str or Timestamp, default "2000-01-01"
        The start of the index. Passed to date_range with `freq`.
    end : str or Timestamp, default "2000-12-31"
        The end of the index. Passed to date_range with `freq`.
    freq : str or Freq
        The frequency to use for the DatetimeIndex
    seed : int, optional
        The random state seed.

    Returns
    -------
    DataFrame
        A DataFrame with the following columns:

        * name : object dtype with string names
        * id : int dtype
        * x, y : float dtype

    Examples
    --------
    >>> _make_timeseries()
                  id    name         x         y
    timestamp
    2000-01-01   982   Frank  0.031261  0.986727
    2000-01-02  1025   Edith -0.086358 -0.032920
    2000-01-03   982   Edith  0.473177  0.298654
    2000-01-04  1009   Sarah  0.534344 -0.750377
    2000-01-05   963   Zelda -0.271573  0.054424
    ...          ...     ...       ...       ...
    2000-12-27   980  Ingrid -0.132333 -0.422195
    2000-12-28   972   Frank -0.376007 -0.298687
    2000-12-29  1009  Ursula -0.865047 -0.503133
    2000-12-30  1000  Hannah -0.063757 -0.507336
    2000-12-31   972     Tim -0.869120  0.531685
    """
    index = pd.date_range(start=start, end=end, freq=freq, name="timestamp")
    n = len(index)
    state = np.random.RandomState(seed)
    columns = {
        "name": state.choice(_names, size=n),
        "id": state.poisson(1000, size=n),
        "x": state.rand(n) * 2 - 1,
        "y": state.rand(n) * 2 - 1,
    }
    df = pd.DataFrame(columns, index=index, columns=sorted(columns))
    if df.index[-1] == end:
        df = df.iloc[:-1]
    return df


def all_index_generator(k=10):
    """Generator which can be iterated over to get instances of all the various
    index classes.

    Parameters
    ----------
    k: length of each of the index instances
    """
    all_make_index_funcs = [
        makeIntIndex,
        makeFloatIndex,
        makeStringIndex,
        makeUnicodeIndex,
        makeDateIndex,
        makePeriodIndex,
        makeTimedeltaIndex,
        makeBoolIndex,
        makeRangeIndex,
        makeIntervalIndex,
        makeCategoricalIndex,
    ]
    for make_index_func in all_make_index_funcs:
        yield make_index_func(k=k)


def index_subclass_makers_generator():
    make_index_funcs = [
        makeDateIndex,
        makePeriodIndex,
        makeTimedeltaIndex,
        makeRangeIndex,
        makeIntervalIndex,
        makeCategoricalIndex,
        makeMultiIndex,
    ]
    for make_index_func in make_index_funcs:
        yield make_index_func


def all_timeseries_index_generator(k=10):
    """Generator which can be iterated over to get instances of all the classes
    which represent time-series.

    Parameters
    ----------
    k: length of each of the index instances
    """
    make_index_funcs = [makeDateIndex, makePeriodIndex, makeTimedeltaIndex]
    for make_index_func in make_index_funcs:
        yield make_index_func(k=k)


1775# make series 

1776def makeFloatSeries(name=None): 

1777 index = makeStringIndex(N) 

1778 return Series(randn(N), index=index, name=name) 

1779 

1780 

1781def makeStringSeries(name=None): 

1782 index = makeStringIndex(N) 

1783 return Series(randn(N), index=index, name=name) 

1784 

1785 

1786def makeObjectSeries(name=None): 

1787 data = makeStringIndex(N) 

1788 data = Index(data, dtype=object) 

1789 index = makeStringIndex(N) 

1790 return Series(data, index=index, name=name) 

1791 

1792 

1793def getSeriesData(): 

1794 index = makeStringIndex(N) 

1795 return {c: Series(randn(N), index=index) for c in getCols(K)} 

1796 

1797 

1798def makeTimeSeries(nper=None, freq="B", name=None): 

1799 if nper is None: 

1800 nper = N 

1801 return Series(randn(nper), index=makeDateIndex(nper, freq=freq), name=name) 

1802 

1803 

1804def makePeriodSeries(nper=None, name=None): 

1805 if nper is None: 

1806 nper = N 

1807 return Series(randn(nper), index=makePeriodIndex(nper), name=name) 

1808 

1809 

1810def getTimeSeriesData(nper=None, freq="B"): 

1811 return {c: makeTimeSeries(nper, freq) for c in getCols(K)} 

1812 

1813 

1814def getPeriodData(nper=None): 

1815 return {c: makePeriodSeries(nper) for c in getCols(K)} 

1816 

1817 

1818# make frame 

1819def makeTimeDataFrame(nper=None, freq="B"): 

1820 data = getTimeSeriesData(nper, freq) 

1821 return DataFrame(data) 

1822 

1823 

1824def makeDataFrame(): 

1825 data = getSeriesData() 

1826 return DataFrame(data) 

1827 

1828 

1829def getMixedTypeDict(): 

1830 index = Index(["a", "b", "c", "d", "e"]) 

1831 

1832 data = { 

1833 "A": [0.0, 1.0, 2.0, 3.0, 4.0], 

1834 "B": [0.0, 1.0, 0.0, 1.0, 0.0], 

1835 "C": ["foo1", "foo2", "foo3", "foo4", "foo5"], 

1836 "D": bdate_range("1/1/2009", periods=5), 

1837 } 

1838 

1839 return index, data 

1840 

1841 

1842def makeMixedDataFrame(): 

1843 return DataFrame(getMixedTypeDict()[1]) 

1844 

1845 

1846def makePeriodFrame(nper=None): 

1847 data = getPeriodData(nper) 

1848 return DataFrame(data) 

1849 

1850 

1851def makeCustomIndex( 

1852 nentries, nlevels, prefix="#", names=False, ndupe_l=None, idx_type=None 

1853): 

1854 """Create an index/multindex with given dimensions, levels, names, etc' 

1855 

1856 nentries - number of entries in index 

1857 nlevels - number of levels (> 1 produces multiindex) 

1858 prefix - a string prefix for labels 

1859 names - (Optional), bool or list of strings. If True, default names are 

1860 used; if False, no names are used; if a list is given, the name of 

1861 each level in the index will be taken from the list. 

1862 ndupe_l - (Optional), list of ints, the number of rows for which the 

1863 label will be repeated at the corresponding level, you can specify 

1864 just the first few, the rest will use the default ndupe_l of 1. 

1865 len(ndupe_l) <= nlevels. 

1866 idx_type - "i"/"f"/"s"/"u"/"dt"/"p"/"td". 

1867 If idx_type is not None, `nlevels` must be 1. 

1868 "i"/"f" creates an integer/float index, 

1869 "s"/"u" creates a string/unicode index, 

1870 "dt" creates a datetime index, "p" a period index. 

1871 "td" creates a timedelta index. 

1872 

1873 if unspecified, string labels will be generated. 

1874 """ 

1875 

1876 if ndupe_l is None: 

1877 ndupe_l = [1] * nlevels 

1878 assert is_sequence(ndupe_l) and len(ndupe_l) <= nlevels 

1879 assert names is None or names is False or names is True or len(names) == nlevels 

1880 assert idx_type is None or ( 

1881 idx_type in ("i", "f", "s", "u", "dt", "p", "td") and nlevels == 1 

1882 ) 

1883 

1884 if names is True: 

1885 # build default names 

1886 names = [prefix + str(i) for i in range(nlevels)] 

1887 if names is False: 

1888 # pass None to index constructor for no name 

1889 names = None 

1890 

1891 # make singleton case uniform 

1892 if isinstance(names, str) and nlevels == 1: 

1893 names = [names] 

1894 

1895 # specific 1D index type requested? 

1896 idx_func = dict( 

1897 i=makeIntIndex, 

1898 f=makeFloatIndex, 

1899 s=makeStringIndex, 

1900 u=makeUnicodeIndex, 

1901 dt=makeDateIndex, 

1902 td=makeTimedeltaIndex, 

1903 p=makePeriodIndex, 

1904 ).get(idx_type) 

1905 if idx_func: 

1906 idx = idx_func(nentries) 

1907 # but we need to fill in the name 

1908 if names: 

1909 idx.name = names[0] 

1910 return idx 

1911 elif idx_type is not None: 

1912 raise ValueError( 

1913 f"{repr(idx_type)} is not a legal value for `idx_type`, " 

1914 "use 'i'/'f'/'s'/'u'/'dt'/'p'/'td'." 

1915 ) 

1916 

1917 if len(ndupe_l) < nlevels: 

1918 ndupe_l.extend([1] * (nlevels - len(ndupe_l))) 

1919 assert len(ndupe_l) == nlevels 

1920 

1921 assert all(x > 0 for x in ndupe_l) 

1922 

1923 tuples = [] 

1924 for i in range(nlevels): 

1925 

1926 def keyfunc(x): 

1927 import re 

1928 

1929 numeric_tuple = re.sub(r"[^\d_]_?", "", x).split("_") 

1930 return [int(num) for num in numeric_tuple] 

1931 

1932 # build a list of lists to create the index from 

1933 div_factor = nentries // ndupe_l[i] + 1 

1934 cnt = Counter() 

1935 for j in range(div_factor): 

1936 label = f"{prefix}_l{i}_g{j}" 

1937 cnt[label] = ndupe_l[i] 

1938 # cute Counter trick 

1939 result = sorted(cnt.elements(), key=keyfunc)[:nentries] 

1940 tuples.append(result) 

1941 

1942 tuples = list(zip(*tuples)) 

1943 

1944 # convert tuples to index 

1945 if nentries == 1: 

1946 # we have a single level of tuples, i.e. a regular Index 

1947 index = Index(tuples[0], name=None if names is None else names[0]) 

1948 elif nlevels == 1: 

1949 name = None if names is None else names[0] 

1950 index = Index((x[0] for x in tuples), name=name) 

1951 else: 

1952 index = MultiIndex.from_tuples(tuples, names=names) 

1953 return index 
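# Quick sketch of the mechanics documented above (illustrative only):
def _example_make_custom_index():
    # Two levels; each level-0 label repeats twice, so 4 entries use only
    # two distinct level-0 labels. names=True picks the defaults "#0", "#1".
    mi = makeCustomIndex(nentries=4, nlevels=2, ndupe_l=[2], names=True)
    assert isinstance(mi, MultiIndex)
    # A 1-level request with idx_type="dt" defers to makeDateIndex.
    dti = makeCustomIndex(nentries=4, nlevels=1, idx_type="dt")
    assert isinstance(dti, DatetimeIndex)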

1954 

1955 

1956def makeCustomDataframe( 

1957 nrows, 

1958 ncols, 

1959 c_idx_names=True, 

1960 r_idx_names=True, 

1961 c_idx_nlevels=1, 

1962 r_idx_nlevels=1, 

1963 data_gen_f=None, 

1964 c_ndupe_l=None, 

1965 r_ndupe_l=None, 

1966 dtype=None, 

1967 c_idx_type=None, 

1968 r_idx_type=None, 

1969): 

1970 """ 

1971 nrows, ncols - number of data rows/cols 

1972 c_idx_names, r_idx_names - False/True/list of strings, yields no names, 

1973 default names, or uses the provided names for the levels of the 

1974 corresponding index. You can provide a single string when 

1975 c_idx_nlevels == 1. 

1976 c_idx_nlevels - number of levels in columns index. > 1 will yield MultiIndex 

1977 r_idx_nlevels - number of levels in rows index. > 1 will yield MultiIndex 

1978 data_gen_f - a function f(row,col) which returns the data value 

1979 at that position, the default generator used yields values of the form 

1980 "RxCy" based on position. 

1981 c_ndupe_l, r_ndupe_l - list of integers, determines the number 

1982 of duplicates for each label at a given level of the corresponding 

1983 index. The default `None` value produces a multiplicity of 1 across 

1984 all levels, i.e. a unique index. Will accept a partial list of length 

1985 N < idx_nlevels, for just the first N levels. If ndupe doesn't divide 

1986 nrows/ncol, the last label might have lower multiplicity. 

1987 dtype - passed to the DataFrame constructor as is, in case you wish to 

1988 have more control in conjunction with a custom `data_gen_f` 

1989 r_idx_type, c_idx_type - "i"/"f"/"s"/"u"/"dt"/"p"/"td". 

1990 If idx_type is not None, `idx_nlevels` must be 1. 

1991 "i"/"f" creates an integer/float index, 

1992 "s"/"u" creates a string/unicode index, 

1993 "dt" creates a datetime index, "p" a period index. 

1994 "td" creates a timedelta index. 

1995 

1996 if unspecified, string labels will be generated. 

1997 

1998 Examples: 

1999 

2000 # 5 rows, 3 columns, default names on both, single index on both axes 

2001 >> makeCustomDataframe(5,3) 

2002 

2003 # make the data a random int between 1 and 100 

2004 >> makeCustomDataframe(5, 3, data_gen_f=lambda r, c: np.random.randint(1, 100)) 

2005 

2006 # 2-level multiindex on rows with each label duplicated 

2007 # twice on first level, default names on both axes, single 

2008 # index on columns 

2009 >> a=makeCustomDataframe(5,3,r_idx_nlevels=2,r_ndupe_l=[2]) 

2010 

2011 # DatetimeIndex on row, index with unicode labels on columns 

2012 # no names on either axis 

2013 >> a=makeCustomDataframe(5,3,c_idx_names=False,r_idx_names=False, 

2014 r_idx_type="dt",c_idx_type="u") 

2015 

2016 # 4-level multiindex on rows with names provided, 2-level multiindex 

2017 # on columns with default labels and default names. 

2018 >> a=makeCustomDataframe(5,3,r_idx_nlevels=4, 

2019 r_idx_names=["FEE","FI","FO","FAM"], 

2020 c_idx_nlevels=2) 

2021 

2022 >> a = makeCustomDataframe(5, 3, r_idx_nlevels=2, c_idx_nlevels=4) 

2023 """ 

2024 

2025 assert c_idx_nlevels > 0 

2026 assert r_idx_nlevels > 0 

2027 assert r_idx_type is None or ( 

2028 r_idx_type in ("i", "f", "s", "u", "dt", "p", "td") and r_idx_nlevels == 1 

2029 ) 

2030 assert c_idx_type is None or ( 

2031 c_idx_type in ("i", "f", "s", "u", "dt", "p", "td") and c_idx_nlevels == 1 

2032 ) 

2033 

2034 columns = makeCustomIndex( 

2035 ncols, 

2036 nlevels=c_idx_nlevels, 

2037 prefix="C", 

2038 names=c_idx_names, 

2039 ndupe_l=c_ndupe_l, 

2040 idx_type=c_idx_type, 

2041 ) 

2042 index = makeCustomIndex( 

2043 nrows, 

2044 nlevels=r_idx_nlevels, 

2045 prefix="R", 

2046 names=r_idx_names, 

2047 ndupe_l=r_ndupe_l, 

2048 idx_type=r_idx_type, 

2049 ) 

2050 

2051 # by default, generate data based on location 

2052 if data_gen_f is None: 

2053 data_gen_f = lambda r, c: f"R{r}C{c}" 

2054 

2055 data = [[data_gen_f(r, c) for c in range(ncols)] for r in range(nrows)] 

2056 

2057 return DataFrame(data, index, columns, dtype=dtype) 

2058 

2059 

2060def _create_missing_idx(nrows, ncols, density, random_state=None): 

2061 if random_state is None: 

2062 random_state = np.random 

2063 else: 

2064 random_state = np.random.RandomState(random_state) 

2065 

2066 # below is cribbed from scipy.sparse 

2067 size = int(np.round((1 - density) * nrows * ncols)) 

2068 # generate a few more to ensure unique values 

2069 min_rows = 5 

2070 fac = 1.02 

2071 extra_size = min(size + min_rows, fac * size) 

2072 

2073 def _gen_unique_rand(rng, _extra_size): 

2074 ind = rng.rand(int(_extra_size)) 

2075 return np.unique(np.floor(ind * nrows * ncols))[:size] 

2076 

2077 ind = _gen_unique_rand(random_state, extra_size) 

2078 while ind.size < size: 

2079 extra_size *= 1.05 

2080 ind = _gen_unique_rand(random_state, extra_size) 

2081 

2082 j = np.floor(ind * 1.0 / nrows).astype(int) 

2083 i = (ind - j * nrows).astype(int) 

2084 return i.tolist(), j.tolist() 
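# Worked sketch of the density arithmetic above (illustrative): for an
# 8x5 frame with density=0.9, 10% of the 40 cells -- 4 positions -- are
# drawn without replacement to become missing.
def _example_create_missing_idx():
    i, j = _create_missing_idx(8, 5, density=0.9, random_state=42)
    assert len(i) == len(j) == int(np.round((1 - 0.9) * 8 * 5))  # 4 cells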

2085 

2086 

2087def makeMissingCustomDataframe( 

2088 nrows, 

2089 ncols, 

2090 density=0.9, 

2091 random_state=None, 

2092 c_idx_names=True, 

2093 r_idx_names=True, 

2094 c_idx_nlevels=1, 

2095 r_idx_nlevels=1, 

2096 data_gen_f=None, 

2097 c_ndupe_l=None, 

2098 r_ndupe_l=None, 

2099 dtype=None, 

2100 c_idx_type=None, 

2101 r_idx_type=None, 

2102): 

2103 """ 

2104 Parameters 

2105 ---------- 

2106 density : float, optional 

2107 Float in (0, 1) that gives the fraction of non-missing numbers in 

2108 the DataFrame. 

2109 random_state : {np.random.RandomState, int}, optional 

2110 Random number generator or random seed. 

2111 

2112 See makeCustomDataframe for descriptions of the rest of the parameters. 

2113 """ 

2114 df = makeCustomDataframe( 

2115 nrows, 

2116 ncols, 

2117 c_idx_names=c_idx_names, 

2118 r_idx_names=r_idx_names, 

2119 c_idx_nlevels=c_idx_nlevels, 

2120 r_idx_nlevels=r_idx_nlevels, 

2121 data_gen_f=data_gen_f, 

2122 c_ndupe_l=c_ndupe_l, 

2123 r_ndupe_l=r_ndupe_l, 

2124 dtype=dtype, 

2125 c_idx_type=c_idx_type, 

2126 r_idx_type=r_idx_type, 

2127 ) 

2128 

2129 i, j = _create_missing_idx(nrows, ncols, density, random_state) 

2130 df.values[i, j] = np.nan 

2131 return df 

2132 

2133 

2134def makeMissingDataframe(density=0.9, random_state=None): 

2135 df = makeDataFrame() 

2136 i, j = _create_missing_idx(*df.shape, density=density, random_state=random_state) 

2137 df.values[i, j] = np.nan 

2138 return df 

2139 

2140 

2141def optional_args(decorator): 

2142 """allows a decorator to take optional positional and keyword arguments. 

2143 Assumes that taking a single, callable, positional argument means that 

2144 it is decorating a function, i.e. something like this:: 

2145 

2146 @my_decorator 

2147 def function(): pass 

2148 

2149 Calls decorator with decorator(f, *args, **kwargs)""" 

2150 

2151 @wraps(decorator) 

2152 def wrapper(*args, **kwargs): 

2153 def dec(f): 

2154 return decorator(f, *args, **kwargs) 

2155 

2156 is_decorating = not kwargs and len(args) == 1 and callable(args[0]) 

2157 if is_decorating: 

2158 f = args[0] 

2159 args = [] 

2160 return dec(f) 

2161 else: 

2162 return dec 

2163 

2164 return wrapper 
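# Minimal sketch of the two calling conventions optional_args enables;
# `_tag` is a made-up decorator for illustration, not part of this module.
@optional_args
def _tag(f, label="default"):
    f.label = label
    return f


@_tag  # bare use: the function itself is the single callable argument
def _plain():
    pass


@_tag(label="special")  # keyword use: arguments are forwarded to _tag
def _fancy():
    pass


# _plain.label == "default"; _fancy.label == "special"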

2165 

2166 

2167# skip tests on exceptions with this message 

2168_network_error_messages = ( 

2169 # 'urlopen error timed out', 

2170 # 'timeout: timed out', 

2171 # 'socket.timeout: timed out', 

2172 "timed out", 

2173 "Server Hangup", 

2174 "HTTP Error 503: Service Unavailable", 

2175 "502: Proxy Error", 

2176 "HTTP Error 502: internal error", 

2177 "HTTP Error 502", 

2178 "HTTP Error 503", 

2179 "HTTP Error 403", 

2180 "HTTP Error 400", 

2181 "Temporary failure in name resolution", 

2182 "Name or service not known", 

2183 "Connection refused", 

2184 "certificate verify", 

2185) 

2186 

2187# or this e.errno/e.reason.errno 

2188_network_errno_vals = ( 

2189 101, # Network is unreachable 

2190 111, # Connection refused 

2191 110, # Connection timed out 

2192 104, # Connection reset Error 

2193 54, # Connection reset by peer 

2194 60, # urllib.error.URLError: [Errno 60] Connection timed out 

2195) 

2196 

2197# Both of the above shouldn't mask real issues such as 404's 

2198# or refused connections (changed DNS). 

2199 # But some tests (test_data yahoo) contact incredibly flaky 

2200# servers. 

2201 

2202# and conditionally raise on exception types in _get_default_network_errors 

2203 

2204 

2205def _get_default_network_errors(): 

2206 # Lazy import for http.client because it imports many things from the stdlib 

2207 import http.client 

2208 

2209 return (IOError, http.client.HTTPException, TimeoutError) 

2210 

2211 

2212def can_connect(url, error_classes=None): 

2213 """Try to connect to the given url. True if succeeds, False if IOError 

2214 raised 

2215 

2216 Parameters 

2217 ---------- 

2218 url : str 

2219 The URL to try to connect to 

2220 

2221 Returns 

2222 ------- 

2223 connectable : bool 

2224 Return True if no IOError (unable to connect) or URLError (bad url) was 

2225 raised 

2226 """ 

2227 

2228 if error_classes is None: 

2229 error_classes = _get_default_network_errors() 

2230 

2231 try: 

2232 with urlopen(url): 

2233 pass 

2234 except error_classes: 

2235 return False 

2236 else: 

2237 return True 

2238 

2239 

2240@optional_args 

2241def network( 

2242 t, 

2243 url="http://www.google.com", 

2244 raise_on_error=_RAISE_NETWORK_ERROR_DEFAULT, 

2245 check_before_test=False, 

2246 error_classes=None, 

2247 skip_errnos=_network_errno_vals, 

2248 _skip_on_messages=_network_error_messages, 

2249): 

2250 """ 

2251 Label a test as requiring network connection and, if an error is 

2252 encountered, only raise if it does not find a network connection. 

2253 

2254 This decorator assumes an added contract for your test: 

2255 you must assert that, under normal conditions, your test will ONLY fail if 

2256 it does not have network connectivity. 

2257 

2258 You can call this in 3 ways: as a standard decorator, with keyword 

2259 arguments, or with a positional argument that is the url to check. 

2260 

2261 Parameters 

2262 ---------- 

2263 t : callable 

2264 The test requiring network connectivity. 

2265 url : path 

2266 The url to test via ``pandas.io.common.urlopen`` to check 

2267 for connectivity. Defaults to 'http://www.google.com'. 

2268 raise_on_error : bool 

2269 If True, never catches errors. 

2270 check_before_test : bool 

2271 If True, checks connectivity before running the test case. 

2272 error_classes : tuple or Exception 

2273 error classes to ignore. If not in ``error_classes``, raises the error. 

2274 defaults to IOError. Be careful about changing the error classes here. 

2275 skip_errnos : iterable of int 

2276 Any exception that has .errno or .reason.errno set to one 

2277 of these values will be skipped with an appropriate 

2278 message. 

2279 _skip_on_messages : iterable of string 

2280 Any exception e for which one of the strings is 

2281 a substring of str(e) will be skipped with an appropriate 

2282 message. Intended to suppress errors where an errno isn't available. 

2283 

2284 Notes 

2285 ----- 

2286 * ``raise_on_error`` supersedes ``check_before_test`` 

2287 

2288 Returns 

2289 ------- 

2290 t : callable 

2291 The decorated test ``t``, with checks for connectivity errors. 

2292 

2293 Examples 

2294 -------- 

2295 

2296 Tests decorated with @network will only raise on failure if it's possible 

2297 to make a network connection to another URL (defaults to google.com):: 

2298 

2299 >>> from pandas._testing import network 

2300 >>> from pandas.io.common import urlopen 

2301 >>> @network 

2302 ... def test_network(): 

2303 ... with urlopen("rabbit://bonanza.com"): 

2304 ... pass 

2305 Traceback (most recent call last): 

2306 ... 

2307 URLError: <urlopen error unknown url type: rabbit> 

2308 

2309 You can specify alternative URLs:: 

2310 

2311 >>> @network("http://www.yahoo.com") 

2312 ... def test_something_with_yahoo(): 

2313 ... raise IOError("Failure Message") 

2314 >>> test_something_with_yahoo() 

2315 Traceback (most recent call last): 

2316 ... 

2317 IOError: Failure Message 

2318 

2319 If you set check_before_test, it will check the url first and not run the 

2320 test on failure:: 

2321 

2322 >>> @network("failing://url.blaher", check_before_test=True) 

2323 ... def test_something(): 

2324 ... print("I ran!") 

2325 ... raise ValueError("Failure") 

2326 >>> test_something() 

2327 Traceback (most recent call last): 

2328 ... 

2329 

2330 Errors not related to networking will always be raised. 

2331 """ 

2332 from pytest import skip 

2333 

2334 if error_classes is None: 

2335 error_classes = _get_default_network_errors() 

2336 

2337 t.network = True 

2338 

2339 @wraps(t) 

2340 def wrapper(*args, **kwargs): 

2341 if check_before_test and not raise_on_error: 

2342 if not can_connect(url, error_classes): 

2343 skip() 

2344 try: 

2345 return t(*args, **kwargs) 

2346 except Exception as err: 

2347 errno = getattr(err, "errno", None) 

2348 if not errno and hasattr(err, "reason"): 

2349 errno = getattr(err.reason, "errno", None) 

2350 

2351 if errno in skip_errnos: 

2352 skip(f"Skipping test due to known errno and error {err}") 

2353 

2354 e_str = str(err) 

2355 

2356 if any(m.lower() in e_str.lower() for m in _skip_on_messages): 

2357 skip( 

2358 f"Skipping test because exception message is known and error {err}" 

2359 ) 

2360 

2361 if not isinstance(err, error_classes): 

2362 raise 

2363 

2364 if raise_on_error or can_connect(url, error_classes): 

2365 raise 

2366 else: 

2367 skip(f"Skipping test due to lack of connectivity and error {err}") 

2368 

2369 return wrapper 

2370 

2371 

2372with_connectivity_check = network 

2373 

2374 

2375@contextmanager 

2376def assert_produces_warning( 

2377 expected_warning=Warning, 

2378 filter_level="always", 

2379 clear=None, 

2380 check_stacklevel=True, 

2381 raise_on_extra_warnings=True, 

2382): 

2383 """ 

2384 Context manager for running code expected to either raise a specific 

2385 warning, or not raise any warnings. Verifies that the code raises the 

2386 expected warning, and that it does not raise any other unexpected 

2387 warnings. It is basically a wrapper around ``warnings.catch_warnings``. 

2388 

2389 Parameters 

2390 ---------- 

2391 expected_warning : {Warning, False, None}, default Warning 

2392 The type of warning expected. ``Warning`` is the base 

2393 class for all warnings. To check that no warning is returned, 

2394 specify ``False`` or ``None``. 

2395 filter_level : str or None, default "always" 

2396 Specifies whether warnings are ignored, displayed, or turned 

2397 into errors. 

2398 Valid values are: 

2399 

2400 * "error" - turns matching warnings into exceptions 

2401 * "ignore" - discard the warning 

2402 * "always" - always emit a warning 

2403 * "default" - print the warning the first time it is generated 

2404 from each location 

2405 * "module" - print the warning the first time it is generated 

2406 from each module 

2407 * "once" - print the warning the first time it is generated 

2408 

2409 clear : module or list of modules, default None 

2410 If not ``None`` then remove any previously raised warnings from 

2411 the ``__warningregistry__`` to ensure that no warning messages are 

2412 suppressed by this context manager. If ``None`` is specified, 

2413 the ``__warningregistry__`` keeps track of which warnings have been 

2414 shown, and does not show them again. 

2415 check_stacklevel : bool, default True 

2416 If True, displays the line that called the function containing 

2417 the warning to show where the function is called. Otherwise, the 

2418 line that implements the function is displayed. 

2419 raise_on_extra_warnings : bool, default True 

2420 Whether extra warnings not of the type `expected_warning` should 

2421 cause the test to fail. 

2422 

2423 Examples 

2424 -------- 

2425 >>> import warnings 

2426 >>> with assert_produces_warning(): 

2427 ... warnings.warn(UserWarning()) 

2428 ... 

2429 >>> with assert_produces_warning(False): 

2430 ... warnings.warn(RuntimeWarning()) 

2431 ... 

2432 Traceback (most recent call last): 

2433 ... 

2434 AssertionError: Caused unexpected warning(s): ['RuntimeWarning']. 

2435 >>> with assert_produces_warning(UserWarning): 

2436 ... warnings.warn(RuntimeWarning()) 

2437 Traceback (most recent call last): 

2438 ... 

2439 AssertionError: Did not see expected warning of class 'UserWarning'. 

2440 

2441 .. warning:: This is *not* thread-safe. 

2442 """ 

2443 __tracebackhide__ = True 

2444 

2445 with warnings.catch_warnings(record=True) as w: 

2446 

2447 if clear is not None: 

2448 # make sure that we are clearing these warnings 

2449 # if they have happened before 

2450 # to guarantee that we will catch them 

2451 if not is_list_like(clear): 

2452 clear = [clear] 

2453 for m in clear: 

2454 try: 

2455 m.__warningregistry__.clear() 

2456 except AttributeError: 

2457 # module may not have __warningregistry__ 

2458 pass 

2459 

2460 saw_warning = False 

2461 warnings.simplefilter(filter_level) 

2462 yield w 

2463 extra_warnings = [] 

2464 

2465 for actual_warning in w: 

2466 if expected_warning and issubclass( 

2467 actual_warning.category, expected_warning 

2468 ): 

2469 saw_warning = True 

2470 

2471 if check_stacklevel and issubclass( 

2472 actual_warning.category, (FutureWarning, DeprecationWarning) 

2473 ): 

2474 from inspect import getframeinfo, stack 

2475 

2476 caller = getframeinfo(stack()[2][0]) 

2477 msg = ( 

2478 "Warning not set with correct stacklevel. " 

2479 f"File where warning is raised: {actual_warning.filename} != " 

2480 f"{caller.filename}. Warning message: {actual_warning.message}" 

2481 ) 

2482 assert actual_warning.filename == caller.filename, msg 

2483 else: 

2484 extra_warnings.append( 

2485 ( 

2486 actual_warning.category.__name__, 

2487 actual_warning.message, 

2488 actual_warning.filename, 

2489 actual_warning.lineno, 

2490 ) 

2491 ) 

2492 if expected_warning: 

2493 msg = ( 

2494 f"Did not see expected warning of class " 

2495 f"{repr(expected_warning.__name__)}" 

2496 ) 

2497 assert saw_warning, msg 

2498 if raise_on_extra_warnings and extra_warnings: 

2499 raise AssertionError( 

2500 f"Caused unexpected warning(s): {repr(extra_warnings)}" 

2501 ) 

2502 

2503 

2504class RNGContext: 

2505 """ 

2506 Context manager to set the numpy random number generator seed. Returns 

2507 to the original value upon exiting the context manager. 

2508 

2509 Parameters 

2510 ---------- 

2511 seed : int 

2512 Seed for numpy.random.seed 

2513 

2514 Examples 

2515 -------- 

2516 

2517 >>> with RNGContext(42): 

2518 ...     np.random.randn() 

2519 """ 

2520 

2521 def __init__(self, seed): 

2522 self.seed = seed 

2523 

2524 def __enter__(self): 

2525 

2526 self.start_state = np.random.get_state() 

2527 np.random.seed(self.seed) 

2528 

2529 def __exit__(self, exc_type, exc_value, traceback): 

2530 

2531 np.random.set_state(self.start_state) 

2532 

2533 

2534@contextmanager 

2535def with_csv_dialect(name, **kwargs): 

2536 """ 

2537 Context manager to temporarily register a CSV dialect for parsing CSV. 

2538 

2539 Parameters 

2540 ---------- 

2541 name : str 

2542 The name of the dialect. 

2543 kwargs : mapping 

2544 The parameters for the dialect. 

2545 

2546 Raises 

2547 ------ 

2548 ValueError : the name of the dialect conflicts with a builtin one. 

2549 

2550 See Also 

2551 -------- 

2552 csv : Python's CSV library. 

2553 """ 

2554 import csv 

2555 

2556 _BUILTIN_DIALECTS = {"excel", "excel-tab", "unix"} 

2557 

2558 if name in _BUILTIN_DIALECTS: 

2559 raise ValueError("Cannot override builtin dialect.") 

2560 

2561 csv.register_dialect(name, **kwargs) 

2562 yield 

2563 csv.unregister_dialect(name) 
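# Usage sketch (illustrative): register a throwaway dialect just for the
# duration of one read_csv call.
def _example_with_csv_dialect():
    import io

    with with_csv_dialect("pipes", delimiter="|"):
        df = pd.read_csv(io.StringIO("a|b\n1|2\n"), dialect="pipes")
    return df  # columns a and b with a single row 1, 2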

2564 

2565 

2566@contextmanager 

2567def use_numexpr(use, min_elements=None): 

2568 from pandas.core.computation import expressions as expr 

2569 

2570 if min_elements is None: 

2571 min_elements = expr._MIN_ELEMENTS 

2572 

2573 olduse = expr._USE_NUMEXPR 

2574 oldmin = expr._MIN_ELEMENTS 

2575 expr.set_use_numexpr(use) 

2576 expr._MIN_ELEMENTS = min_elements 

2577 yield 

2578 expr._MIN_ELEMENTS = oldmin 

2579 expr.set_use_numexpr(olduse) 
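# Usage sketch (illustrative): toggle the numexpr-backed evaluation path
# around an operation and restore the previous settings afterwards.
def _example_use_numexpr():
    df = makeDataFrame()
    with use_numexpr(False):
        plain = df + df  # plain numpy path
    with use_numexpr(True, min_elements=1):
        accelerated = df + df  # numexpr path even for tiny frames
    assert plain.equals(accelerated)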

2580 

2581 

2582def test_parallel(num_threads=2, kwargs_list=None): 

2583 """Decorator to run the same function multiple times in parallel. 

2584 

2585 Parameters 

2586 ---------- 

2587 num_threads : int, optional 

2588 The number of times the function is run in parallel. 

2589 kwargs_list : list of dicts, optional 

2590 The list of kwargs used to update the original 

2591 function kwargs on the different threads. 

2592 Notes 

2593 ----- 

2594 This decorator does not pass the return value of the decorated function. 

2595 

2596 Original from scikit-image: 

2597 

2598 https://github.com/scikit-image/scikit-image/pull/1519 

2599 

2600 """ 

2601 

2602 assert num_threads > 0 

2603 has_kwargs_list = kwargs_list is not None 

2604 if has_kwargs_list: 

2605 assert len(kwargs_list) == num_threads 

2606 import threading 

2607 

2608 def wrapper(func): 

2609 @wraps(func) 

2610 def inner(*args, **kwargs): 

2611 if has_kwargs_list: 

2612 update_kwargs = lambda i: dict(kwargs, **kwargs_list[i]) 

2613 else: 

2614 update_kwargs = lambda i: kwargs 

2615 threads = [] 

2616 for i in range(num_threads): 

2617 updated_kwargs = update_kwargs(i) 

2618 thread = threading.Thread(target=func, args=args, kwargs=updated_kwargs) 

2619 threads.append(thread) 

2620 for thread in threads: 

2621 thread.start() 

2622 for thread in threads: 

2623 thread.join() 

2624 

2625 return inner 

2626 

2627 return wrapper 
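# Usage sketch (illustrative): results are collected through a shared
# list because the decorator discards return values.
def _example_test_parallel():
    results = []

    @test_parallel(num_threads=4)
    def _append_shape():
        results.append(makeDataFrame().shape)

    _append_shape()
    assert len(results) == 4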

2628 

2629 

2630class SubclassedSeries(Series): 

2631 _metadata = ["testattr", "name"] 

2632 

2633 @property 

2634 def _constructor(self): 

2635 return SubclassedSeries 

2636 

2637 @property 

2638 def _constructor_expanddim(self): 

2639 return SubclassedDataFrame 

2640 

2641 

2642class SubclassedDataFrame(DataFrame): 

2643 _metadata = ["testattr"] 

2644 

2645 @property 

2646 def _constructor(self): 

2647 return SubclassedDataFrame 

2648 

2649 @property 

2650 def _constructor_sliced(self): 

2651 return SubclassedSeries 

2652 

2653 

2654class SubclassedCategorical(Categorical): 

2655 @property 

2656 def _constructor(self): 

2657 return SubclassedCategorical 
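# Sketch of what the subclasses above exercise (illustrative): the
# _constructor* properties make pandas operations return the subclass,
# and _metadata lets the custom "testattr" survive those operations.
def _example_subclass_propagation():
    s = SubclassedSeries([1, 2, 3])
    s.testattr = "propagated"
    assert isinstance(s[:2], SubclassedSeries)
    assert isinstance(s.to_frame(), SubclassedDataFrame)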

2658 

2659 

2660@contextmanager 

2661def set_timezone(tz: str): 

2662 """ 

2663 Context manager for temporarily setting a timezone. 

2664 

2665 Parameters 

2666 ---------- 

2667 tz : str 

2668 A string representing a valid timezone. 

2669 

2670 Examples 

2671 -------- 

2672 

2673 >>> from datetime import datetime 

2674 >>> from dateutil.tz import tzlocal 

2675 >>> tzlocal().tzname(datetime.now()) 

2676 'IST' 

2677 

2678 >>> with set_timezone('US/Eastern'): 

2679 ... tzlocal().tzname(datetime.now()) 

2680 ... 

2681 'EDT' 

2682 """ 

2683 

2684 import os 

2685 import time 

2686 

2687 def setTZ(tz): 

2688 if tz is None: 

2689 try: 

2690 del os.environ["TZ"] 

2691 except KeyError: 

2692 pass 

2693 else: 

2694 os.environ["TZ"] = tz 

2695 time.tzset() 

2696 

2697 orig_tz = os.environ.get("TZ") 

2698 setTZ(tz) 

2699 try: 

2700 yield 

2701 finally: 

2702 setTZ(orig_tz) 

2703 

2704 

2705def _make_skipna_wrapper(alternative, skipna_alternative=None): 

2706 """ 

2707 Create a function for calling on an array. 

2708 

2709 Parameters 

2710 ---------- 

2711 alternative : function 

2712 The function to be called on the array with no NaNs. 

2713 Only used when 'skipna_alternative' is None. 

2714 skipna_alternative : function 

2715 The function to be called on the original array 

2716 

2717 Returns 

2718 ------- 

2719 function 

2720 """ 

2721 if skipna_alternative: 

2722 

2723 def skipna_wrapper(x): 

2724 return skipna_alternative(x.values) 

2725 

2726 else: 

2727 

2728 def skipna_wrapper(x): 

2729 nona = x.dropna() 

2730 if len(nona) == 0: 

2731 return np.nan 

2732 return alternative(nona) 

2733 

2734 return skipna_wrapper 
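# Sketch of the two branches above (illustrative): without a
# skipna_alternative, NaNs are dropped before calling `alternative`; with
# one, the raw values are passed straight through.
def _example_skipna_wrapper():
    s = Series([1.0, np.nan, 3.0])
    assert _make_skipna_wrapper(np.sum)(s) == 4.0
    assert _make_skipna_wrapper(np.sum, skipna_alternative=np.nansum)(s) == 4.0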

2735 

2736 

2737def convert_rows_list_to_csv_str(rows_list: List[str]): 

2738 """ 

2739 Convert list of CSV rows to single CSV-formatted string for current OS. 

2740 

2741 This method is used for creating expected value of to_csv() method. 

2742 

2743 Parameters 

2744 ---------- 

2745 rows_list : List[str] 

2746 Each element represents the row of csv. 

2747 

2748 Returns 

2749 ------- 

2750 str 

2751 Expected output of to_csv() in current OS. 

2752 """ 

2753 sep = os.linesep 

2754 expected = sep.join(rows_list) + sep 

2755 return expected
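# Usage sketch (illustrative): build the platform-correct expected string
# for comparing against DataFrame.to_csv(), whose default line terminator
# when returning a string is os.linesep.
def _example_convert_rows_list_to_csv_str():
    df = DataFrame({"a": [1], "b": [2]})
    expected = convert_rows_list_to_csv_str([",a,b", "0,1,2"])
    assert df.to_csv() == expected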