Coverage for maze_dataset/tokenization/save_hashes.py: 0%

31 statements  

coverage.py v7.6.12, created at 2025-03-11 00:49 -0600

1"""generate and save the hashes of all supported tokenizers 

2 

3calls `maze_dataset.tokenization.all_tokenizers.save_hashes()` 

4 

5Usage: 

6 

7To save to the default location (inside package, `maze_dataset/tokenization/MazeTokenizerModular_hashes.npy`): 

8```bash 

9python -m maze_dataset.tokenization.save_hashes 

10``` 

11 

12to save to a custom location: 

13```bash 

14python -m maze_dataset.tokenization.save_hashes /path/to/save/to.npy 

15``` 

16 

17to check hashes shipped with the package: 

18```bash 

19python -m maze_dataset.tokenization.save_hashes --check 

20``` 

21 

22""" 

from pathlib import Path

import numpy as np
from muutils.spinner import SpinnerContext

from maze_dataset.tokenization import all_tokenizers
from maze_dataset.tokenization.maze_tokenizer import (
    _load_tokenizer_hashes,
    get_all_tokenizer_hashes,
)

if __name__ == "__main__":
    # parse args
    # ==================================================
    import argparse

    parser: argparse.ArgumentParser = argparse.ArgumentParser(
        description="generate and save the hashes of all supported tokenizers",
    )

    parser.add_argument("path", type=str, nargs="?", help="path to save the hashes to")
    parser.add_argument(
        "--quiet",
        "-q",
        action="store_true",
        help="disable progress bar and spinner",
    )
    parser.add_argument(
        "--parallelize",
        "-p",
        action="store_true",
        help="parallelize the computation",
    )
    parser.add_argument(
        "--check",
        "-c",
        action="store_true",
        help="save to temp location, then compare to existing",
    )
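    # note: `--check` conflicts with the positional `path`; this is enforced
    # manually below (with a ValueError) rather than via argparse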

    args: argparse.Namespace = parser.parse_args()
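    # when the positional `path` is omitted (nargs="?"), args.path is None and
    # save_hashes() writes to the default in-package location
    # (see the module docstring)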

    if not args.check:
        # write new hashes
        # ==================================================
        all_tokenizers.save_hashes(
            path=args.path,
            verbose=not args.quiet,
            parallelize=args.parallelize,
        )

    else:
        # check hashes only
        # ==================================================

        # set up path
        if args.path is not None:
            raise ValueError("cannot use --check with a custom path")
        temp_path: Path = Path("tests/_temp/tok_hashes.npz")
        temp_path.parent.mkdir(parents=True, exist_ok=True)
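        # writing to a temp path (rather than the package default) lets us
        # compare freshly computed hashes against the shipped ones without
        # overwriting them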

        # generate and save to temp location
        returned_hashes: np.ndarray = all_tokenizers.save_hashes(
            path=temp_path,
            verbose=not args.quiet,
            parallelize=args.parallelize,
        )

        # load saved hashes
        with SpinnerContext(
            spinner_chars="square_dot",
            update_interval=0.5,
            message="loading saved hashes...",
        ):
            read_hashes: np.ndarray = np.load(temp_path)["hashes"]
            read_hashes_pkg: np.ndarray = _load_tokenizer_hashes()
            read_hashes_wrapped: np.ndarray = get_all_tokenizer_hashes()
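        # four hash arrays are now in hand: `returned_hashes` (freshly
        # computed), `read_hashes` (re-read from the temp file),
        # `read_hashes_pkg` (the file shipped with the package), and
        # `read_hashes_wrapped` (via the public accessor, which presumably
        # reads the same packaged file)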

        # compare
        with SpinnerContext(
            spinner_chars="square_dot",
            update_interval=0.01,
            message="checking hashes: ",
            format_string="\r{spinner} ({elapsed_time:.2f}s) {message}{value} ",
            format_string_when_updated=True,
        ) as sp:
            sp.update_value("returned vs read")
            assert np.array_equal(returned_hashes, read_hashes)
            sp.update_value("returned vs _load_tokenizer_hashes")
            assert np.array_equal(returned_hashes, read_hashes_pkg)
            sp.update_value("read vs get_all_tokenizer_hashes()")
            assert np.array_equal(read_hashes, read_hashes_wrapped)
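        # note: equality is transitive, so these three checks together imply
        # that all four arrays match; a failure suggests the packaged hashes
        # are stale and should be regenerated (run this module without --check)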