Coverage for maze_dataset/tokenization/save_hashes.py: 0% (31 statements)
Report generated by coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
"""generate and save the hashes of all supported tokenizers

calls `maze_dataset.tokenization.all_tokenizers.save_hashes()`

Usage:

To save to the default location (inside package, `maze_dataset/tokenization/MazeTokenizerModular_hashes.npy`):
```bash
python -m maze_dataset.tokenization.save_hashes
```

to save to a custom location:
```bash
python -m maze_dataset.tokenization.save_hashes /path/to/save/to.npy
```

to check hashes shipped with the package:
```bash
python -m maze_dataset.tokenization.save_hashes --check
```
"""

from pathlib import Path

import numpy as np
from muutils.spinner import SpinnerContext

import maze_dataset.tokenization.all_tokenizers as all_tokenizers
from maze_dataset.tokenization.maze_tokenizer import (
    _load_tokenizer_hashes,
    get_all_tokenizer_hashes,
)

if __name__ == "__main__":
    # parse args
    # ==================================================
    import argparse

    parser: argparse.ArgumentParser = argparse.ArgumentParser(
        description="generate and save the hashes of all supported tokenizers"
    )
    parser.add_argument("path", type=str, nargs="?", help="path to save the hashes to")
    parser.add_argument(
        "--quiet", "-q", action="store_true", help="disable progress bar and spinner"
    )
    parser.add_argument(
        "--parallelize", "-p", action="store_true", help="parallelize the computation"
    )
    parser.add_argument(
        "--check",
        "-c",
        action="store_true",
        help="save to temp location, then compare to existing",
    )

    args: argparse.Namespace = parser.parse_args()

    if not args.check:
        # write new hashes
        # ==================================================
        # `path=None` means save_hashes uses its default (package-internal) location
        all_tokenizers.save_hashes(
            path=args.path,
            verbose=not args.quiet,
            parallelize=args.parallelize,
        )
    else:
        # check hashes only: regenerate to a temp file, then compare against
        # the hashes shipped with the package
        # ==================================================

        # set up path -- a custom path makes no sense in check mode
        if args.path is not None:
            raise ValueError("cannot use --check with a custom path")
        # NOTE(review): docstring advertises a `.npy` default, but the temp file
        # here is `.npz` and is read back via np.load(...)["hashes"] -- confirm
        # which format `save_hashes` actually writes
        temp_path: Path = Path("tests/_temp/tok_hashes.npz")
        temp_path.parent.mkdir(parents=True, exist_ok=True)

        # generate and save to temp location
        returned_hashes: np.ndarray = all_tokenizers.save_hashes(
            path=temp_path,
            verbose=not args.quiet,
            parallelize=args.parallelize,
        )

        # load saved hashes: from the temp file just written, from the
        # package-internal file, and via the public wrapper
        with SpinnerContext(
            spinner_chars="square_dot",
            update_interval=0.5,
            message="loading saved hashes...",
        ):
            read_hashes: np.ndarray = np.load(temp_path)["hashes"]
            read_hashes_pkg: np.ndarray = _load_tokenizer_hashes()
            read_hashes_wrapped: np.ndarray = get_all_tokenizer_hashes()

        # compare all four hash arrays pairwise (transitively, all must agree)
        with SpinnerContext(
            spinner_chars="square_dot",
            update_interval=0.01,
            message="checking hashes: ",
            format_string="\r{spinner} ({elapsed_time:.2f}s) {message}{value}    ",
            format_string_when_updated=True,
        ) as sp:
            sp.update_value("returned vs read")
            assert np.array_equal(returned_hashes, read_hashes)
            sp.update_value("returned vs _load_tokenizer_hashes")
            assert np.array_equal(returned_hashes, read_hashes_pkg)
            # fixed label: this comparison is read vs wrapped, not returned vs wrapped
            sp.update_value("read vs get_all_tokenizer_hashes()")
            assert np.array_equal(read_hashes, read_hashes_wrapped)