Coverage for maze_dataset/tokenization/save_hashes.py: 0% (31 statements)


1"""generate and save the hashes of all supported tokenizers 

2 

3calls `maze_dataset.tokenization.all_tokenizers.save_hashes()` 

4 

5Usage: 

6 

7To save to the default location (inside package, `maze_dataset/tokenization/MazeTokenizerModular_hashes.npy`): 

8```bash 

9python -m maze_dataset.tokenization.save_hashes 

10``` 

11 

12to save to a custom location: 

13```bash 

14python -m maze_dataset.tokenization.save_hashes /path/to/save/to.npy 

15``` 

16 

17to check hashes shipped with the package: 

18```bash 

19python -m maze_dataset.tokenization.save_hashes --check 

20``` 

21 

22""" 

from pathlib import Path

import numpy as np
from muutils.spinner import SpinnerContext

import maze_dataset.tokenization.all_tokenizers as all_tokenizers
from maze_dataset.tokenization.maze_tokenizer import (
	_load_tokenizer_hashes,
	get_all_tokenizer_hashes,
)


if __name__ == "__main__":
	# parse args
	# ==================================================
	import argparse

	parser: argparse.ArgumentParser = argparse.ArgumentParser(
		description="generate and save the hashes of all supported tokenizers"
	)

	parser.add_argument("path", type=str, nargs="?", help="path to save the hashes to")
	parser.add_argument(
		"--quiet", "-q", action="store_true", help="disable progress bar and spinner"
	)
	parser.add_argument(
		"--parallelize", "-p", action="store_true", help="parallelize the computation"
	)
	parser.add_argument(
		"--check",
		"-c",
		action="store_true",
		help="save to temp location, then compare to existing",
	)

	args: argparse.Namespace = parser.parse_args()
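	# `path` is optional (nargs="?"), so `args.path` is None when no positional
	# argument is given; `save_hashes` is then expected to write to the default
	# in-package location described in the docstring. `--check` cannot be
	# combined with a custom path (enforced below).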


	if not args.check:
		# write new hashes
		# ==================================================
		all_tokenizers.save_hashes(
			path=args.path,
			verbose=not args.quiet,
			parallelize=args.parallelize,
		)

	else:
		# check hashes only
		# ==================================================

		# set up path
		if args.path is not None:
			raise ValueError("cannot use --check with a custom path")
		temp_path: Path = Path("tests/_temp/tok_hashes.npz")
		temp_path.parent.mkdir(parents=True, exist_ok=True)


		# generate and save to temp location
		returned_hashes: np.ndarray = all_tokenizers.save_hashes(
			path=temp_path,
			verbose=not args.quiet,
			parallelize=args.parallelize,
		)
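		# the check compares four arrays: the hashes returned by `save_hashes`,
		# the copy just written to `temp_path`, the hashes shipped inside the
		# package (loaded via `_load_tokenizer_hashes`), and the output of the
		# public accessor `get_all_tokenizer_hashes()`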


		# load saved hashes
		with SpinnerContext(
			spinner_chars="square_dot",
			update_interval=0.5,
			message="loading saved hashes...",
		):
			read_hashes: np.ndarray = np.load(temp_path)["hashes"]
			read_hashes_pkg: np.ndarray = _load_tokenizer_hashes()
			read_hashes_wrapped: np.ndarray = get_all_tokenizer_hashes()
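		# equality is transitive: returned == read, returned == pkg, and
		# read == wrapped together imply that all four arrays agree, so each
		# pair only needs to be checked once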


		# compare
		with SpinnerContext(
			spinner_chars="square_dot",
			update_interval=0.01,
			message="checking hashes: ",
			format_string="\r{spinner} ({elapsed_time:.2f}s) {message}{value} ",
			format_string_when_updated=True,
		) as sp:
			sp.update_value("returned vs read")
			assert np.array_equal(returned_hashes, read_hashes)
			sp.update_value("returned vs _load_tokenizer_hashes")
			assert np.array_equal(returned_hashes, read_hashes_pkg)
			sp.update_value("read vs get_all_tokenizer_hashes()")
			assert np.array_equal(read_hashes, read_hashes_wrapped)
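		# any mismatch raises AssertionError, so a clean exit means the shipped
		# hashes match the freshly computed ones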