Coverage for maze_dataset/tokenization/save_hashes.py: 0%
31 statements
« prev ^ index » next — coverage.py v7.6.12, created at 2025-03-11 00:49 -0600
"""generate and save the hashes of all supported tokenizers

calls `maze_dataset.tokenization.all_tokenizers.save_hashes()`

Usage:

To save to the default location (inside package, `maze_dataset/tokenization/MazeTokenizerModular_hashes.npy`):
```bash
python -m maze_dataset.tokenization.save_hashes
```

to save to a custom location:
```bash
python -m maze_dataset.tokenization.save_hashes /path/to/save/to.npy
```

to check hashes shipped with the package:
```bash
python -m maze_dataset.tokenization.save_hashes --check
```

"""

from pathlib import Path

import numpy as np
from muutils.spinner import SpinnerContext

from maze_dataset.tokenization import all_tokenizers
from maze_dataset.tokenization.maze_tokenizer import (
    _load_tokenizer_hashes,
    get_all_tokenizer_hashes,
)

if __name__ == "__main__":
    # parse args
    # ==================================================
    import argparse

    parser: argparse.ArgumentParser = argparse.ArgumentParser(
        description="generate and save the hashes of all supported tokenizers",
    )
    # optional positional: where to write the hashes (default location used when omitted)
    parser.add_argument("path", type=str, nargs="?", help="path to save the hashes to")
    parser.add_argument(
        "--quiet",
        "-q",
        action="store_true",
        help="disable progress bar and spinner",
    )
    parser.add_argument(
        "--parallelize",
        "-p",
        action="store_true",
        help="parallelize the computation",
    )
    parser.add_argument(
        "--check",
        "-c",
        action="store_true",
        help="save to temp location, then compare to existing",
    )
    args: argparse.Namespace = parser.parse_args()

    if not args.check:
        # write new hashes
        # ==================================================
        all_tokenizers.save_hashes(
            path=args.path,
            verbose=not args.quiet,
            parallelize=args.parallelize,
        )
    else:
        # check hashes only
        # ==================================================

        # set up path -- --check always uses a fixed temp location, so a
        # user-supplied path would be silently ignored; reject it explicitly
        if args.path is not None:
            raise ValueError("cannot use --check with a custom path")
        temp_path: Path = Path("tests/_temp/tok_hashes.npz")
        temp_path.parent.mkdir(parents=True, exist_ok=True)

        # generate and save to temp location
        returned_hashes: np.ndarray = all_tokenizers.save_hashes(
            path=temp_path,
            verbose=not args.quiet,
            parallelize=args.parallelize,
        )

        # load saved hashes three ways: re-read from the temp file, via the
        # package's private loader, and via the public wrapper
        with SpinnerContext(
            spinner_chars="square_dot",
            update_interval=0.5,
            message="loading saved hashes...",
        ):
            read_hashes: np.ndarray = np.load(temp_path)["hashes"]
            read_hashes_pkg: np.ndarray = _load_tokenizer_hashes()
            read_hashes_wrapped: np.ndarray = get_all_tokenizer_hashes()

        # compare -- messages on each assert identify which comparison failed
        with SpinnerContext(
            spinner_chars="square_dot",
            update_interval=0.01,
            message="checking hashes: ",
            format_string="\r{spinner} ({elapsed_time:.2f}s) {message}{value}        ",
            format_string_when_updated=True,
        ) as sp:
            sp.update_value("returned vs read")
            assert np.array_equal(returned_hashes, read_hashes), (
                "hashes returned by save_hashes() do not match those re-read from the temp file"
            )
            sp.update_value("returned vs _load_tokenizer_hashes")
            assert np.array_equal(returned_hashes, read_hashes_pkg), (
                "hashes returned by save_hashes() do not match those shipped with the package"
            )
            # FIX: label previously said "returned vs ..." but the assert below
            # compares the re-read hashes, not the returned ones
            sp.update_value("read vs get_all_tokenizer_hashes()")
            assert np.array_equal(read_hashes, read_hashes_wrapped), (
                "hashes re-read from the temp file do not match get_all_tokenizer_hashes()"
            )