Coverage for maze_dataset\plotting\print_tokens.py: 36%
77 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
1"""Functions to print tokens with colors in different formats
3you can color the tokens by their:
5- type (i.e. adjacency list, origin, target, path) using `color_maze_tokens_AOTP`
6- custom weights (i.e. attention weights) using `color_tokens_cmap`
7- entirely custom colors using `color_tokens_rgb`
9and the output can be in different formats, specified by `FormatType` (html, latex, terminal)
11"""
13import html
14import textwrap
15from typing import Literal, Sequence
17import matplotlib
18import numpy as np
19from IPython.display import HTML, display
20from jaxtyping import UInt8
21from muutils.misc import flatten
23from maze_dataset.constants import SPECIAL_TOKENS
24from maze_dataset.token_utils import tokens_between
RGBArray = UInt8[np.ndarray, "n 3"]
"2D array of RGB colors with shape `(n, 3)` and dtype uint8, one `(r, g, b)` row per color"
FormatType = Literal["html", "latex", "terminal", None]
"output format for the tokens; `None` means the caller supplies its own template and color joiner"

TEMPLATES: dict[FormatType, str] = dict(
    html='<span style="color: black; background-color: rgb({clr})"> {tok} </span>',
    latex=r"\colorbox[RGB]{{ {clr} }}{{ \texttt{{ {tok} }} }}",
    terminal="\033[30m\033[48;2;{clr}m{tok}\033[0m",
)
"per-format `str.format` templates rendering one token `{tok}` on a background color `{clr}`"

_COLOR_JOIN: dict[FormatType, str] = dict(
    html=",",
    latex=",",
    terminal=";",
)
"per-format separator placed between the r, g, b channels when filling `{clr}`"
# translation table for characters that are special in LaTeX text mode.
# applied with a single `str.translate` pass, so the braces introduced by
# `\textbackslash{}` etc. are not themselves escaped a second time
_LATEX_ESCAPE_TABLE: dict[int, str] = str.maketrans(
    {
        "\\": r"\textbackslash{}",
        "{": r"\{",
        "}": r"\}",
        "_": r"\_",
        "#": r"\#",
        "$": r"\$",
        "%": r"\%",
        "&": r"\&",
        "~": r"\textasciitilde{}",
        "^": r"\textasciicircum{}",
    }
)


def _escape_tok(
    tok: str,
    fmt: FormatType,
) -> str:
    """escape token based on format, so it renders literally

    # Parameters:
     - `tok : str`
        token text to escape
     - `fmt : FormatType`
        output format. `"html"` uses `html.escape`, `"latex"` escapes all LaTeX
        special characters, `"terminal"` and `None` return the token unchanged
        (`None` means a custom template is in use, and tokens are not escaped)

    # Returns:
     - `str` the escaped token

    # Raises:
     - `ValueError` : if `fmt` is not a known format
    """
    if fmt == "html":
        return html.escape(tok)
    elif fmt == "latex":
        return tok.translate(_LATEX_ESCAPE_TABLE)
    elif fmt == "terminal" or fmt is None:
        return tok
    else:
        raise ValueError(f"Unexpected format: {fmt}")
def color_tokens_rgb(
    tokens: list,
    colors: Sequence[Sequence[int]],
    fmt: FormatType = "html",
    template: str | None = None,
    clr_join: str | None = None,
    max_length: int | None = None,
) -> str:
    """color tokens from a list with an RGB color array

    each token is rendered with its paired color, and the rendered tokens are joined
    with a single space. tokens will not be escaped if `fmt` is None

    # Parameters:
     - `tokens : list`
        tokens (strings) to color
     - `colors : Sequence[Sequence[int]]`
        one `(r, g, b)` color per token. tokens and colors are paired with `zip`,
        so surplus entries on either side are ignored. channels are converted with `int`
     - `fmt : FormatType`
        output format; selects the template and channel joiner from `TEMPLATES`
        and `_COLOR_JOIN`. if `None`, both `template` and `clr_join` must be given
       (defaults to `"html"`)
     - `template : str | None`
        custom `str.format` template with `{clr}` and `{tok}` fields; only allowed
        when `fmt` is None
       (defaults to `None`)
     - `clr_join : str | None`
        custom separator between color channels; only allowed when `fmt` is None
       (defaults to `None`)
     - `max_length: int | None`: Max number of characters before triggering a line wrap, i.e., making a new colorbox. If `None`, no limit on max length.
        wrapping happens on whitespace only (words longer than `max_length` are
        kept whole), and every chunk keeps the color of the token it came from

    # Returns:
     - `str` the colored tokens joined by spaces

    # Raises:
     - `AssertionError` : if `template` / `clr_join` are inconsistent with `fmt`
    """
    # process format
    if fmt is None:
        assert template is not None, "`template` is required when `fmt` is None"
        assert clr_join is not None, "`clr_join` is required when `fmt` is None"
    else:
        assert template is None, f"`template` must be None when `fmt` is {fmt!r}"
        assert clr_join is None, f"`clr_join` must be None when `fmt` is {fmt!r}"
        template = TEMPLATES[fmt]
        clr_join = _COLOR_JOIN[fmt]

    if max_length is not None:
        # split each token into whitespace-delimited chunks of at most `max_length`
        # characters; each chunk becomes its own colored box
        wrapped: list[list[str]] = [
            textwrap.wrap(
                tok, width=max_length, break_long_words=False, break_on_hyphens=False
            )
            for tok in tokens
        ]
        # repeat each token's color once per chunk produced from that token
        colors = [clr for chunks, clr in zip(wrapped, colors) for _ in chunks]
        tokens = [chunk for chunks in wrapped for chunk in chunks]

    # put everything together
    output: list[str] = [
        template.format(
            clr=clr_join.join(str(int(channel)) for channel in clr),
            # custom templates (fmt=None) get the raw token, as documented above
            tok=tok if fmt is None else _escape_tok(tok, fmt),
        )
        for tok, clr in zip(tokens, colors)
    ]

    return " ".join(output)
def color_tokens_cmap(
    tokens: list[str],
    weights: Sequence[float],
    cmap: str | matplotlib.colors.Colormap = "Blues",
    fmt: FormatType = "html",
    template: str | None = None,
    labels: bool = False,
    clr_join: str | None = None,
) -> str:
    """color tokens given a list of weights and a colormap

    weights are normalized to [0, 1] with `matplotlib.colors.Normalize` (min-max over
    `weights`), mapped through `cmap`, and the resulting RGB colors are rendered with
    `color_tokens_rgb`.

    # Parameters:
     - `tokens : list[str]`
        tokens to color
     - `weights : Sequence[float]`
        one weight per token (e.g. attention weights)
     - `cmap : str | matplotlib.colors.Colormap`
        colormap, or the name of a registered colormap
       (defaults to `"Blues"`)
     - `fmt : FormatType`
        output format, see `color_tokens_rgb`
       (defaults to `"html"`)
     - `template : str | None`
        custom template; only valid together with `fmt=None` and `clr_join`
       (defaults to `None`)
     - `labels : bool`
        if `True`, append a line with each token's raw weight (1 decimal place)
        aligned under the token. only supported for `fmt="terminal"`
       (defaults to `False`)
     - `clr_join : str | None`
        separator between color channels, required alongside `template` when
        `fmt=None`
       (defaults to `None`)

    # Returns:
     - `str` the colored tokens, optionally followed by a line of weight labels

    # Raises:
     - `AssertionError` : if `tokens` and `weights` differ in length, or
       `template` / `clr_join` are inconsistent with `fmt`
     - `NotImplementedError` : if `labels` is set and `fmt` is not `"terminal"`
    """
    assert len(tokens) == len(weights), f"{len(tokens)} != {len(weights)}"
    weights = np.array(weights)

    if isinstance(cmap, str):
        cmap = matplotlib.colormaps.get_cmap(cmap)

    # float RGB values in [0, 255]; `color_tokens_rgb` truncates them with `int`
    colors: np.ndarray
    if len(tokens) == 0:
        # nothing to normalize -- avoid calling `Normalize` on an empty array
        colors = np.zeros((0, 3))
    else:
        # normalize weights to [0, 1]
        weights_norm = matplotlib.colors.Normalize()(weights)
        colors = cmap(weights_norm)[:, :3] * 255

    output: str = color_tokens_rgb(
        tokens=tokens,
        colors=colors,
        fmt=fmt,
        template=template,
        clr_join=clr_join,
    )

    if labels:
        if fmt != "terminal":
            raise NotImplementedError("labels only supported for terminal")
        # align labels with the tokens: the terminal template adds no padding and
        # tokens are joined by a single space, so pad each label to its token's length
        output += "\n"
        for tok, weight in zip(tokens, weights):
            # 1 decimal place, left-aligned and trailing spaces to match token length
            weight_str: str = f"{weight:.1f}"
            # omit if longer than token
            if len(weight_str) > len(tok):
                weight_str = " " * len(tok)
            else:
                weight_str = weight_str.ljust(len(tok))
            output += f"{weight_str} "

    return output
# default colors, chosen to roughly resemble the visual representation of each maze section
_MAZE_TOKENS_DEFAULT_COLORS: dict[tuple[str, str], tuple[int, int, int]] = {
    # adjacency list -> purple
    (SPECIAL_TOKENS.ADJLIST_START, SPECIAL_TOKENS.ADJLIST_END): (217, 210, 233),
    # origin -> green
    (SPECIAL_TOKENS.ORIGIN_START, SPECIAL_TOKENS.ORIGIN_END): (217, 234, 211),
    # target -> red
    (SPECIAL_TOKENS.TARGET_START, SPECIAL_TOKENS.TARGET_END): (234, 209, 220),
    # path -> blue
    (SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END): (207, 226, 243),
}
"default `(start_token, end_token) -> (r, g, b)` colors for each maze token section, roughly matching the format of `as_pixels`"
def color_maze_tokens_AOTP(
    tokens: list[str], fmt: FormatType = "html", template: str | None = None, **kwargs
) -> str:
    """color maze tokens section by section, assuming AOTP format

    AOTP means: adjacency list, origin, target, path. for every `(start, end)` pair in
    `_MAZE_TOKENS_DEFAULT_COLORS`, the tokens returned by `tokens_between` (with both
    the start and end token included) are joined with spaces into one string, and
    each such string is colored with its section's default color via
    `color_tokens_rgb`. extra `kwargs` are forwarded to `color_tokens_rgb`.
    """
    sections: list[str] = []
    for start_tok, end_tok in _MAZE_TOKENS_DEFAULT_COLORS:
        section_tokens = tokens_between(
            tokens, start_tok, end_tok, include_start=True, include_end=True
        )
        sections.append(" ".join(section_tokens))

    # one color row per section, in the same order as the sections above
    section_colors: RGBArray = np.array(
        tuple(_MAZE_TOKENS_DEFAULT_COLORS.values()), dtype=np.uint8
    )

    return color_tokens_rgb(
        tokens=sections,
        colors=section_colors,
        fmt=fmt,
        template=template,
        **kwargs,
    )
def display_html(html: str):
    """display an HTML string via `IPython.display`"""
    rendered = HTML(html)
    display(rendered)
def display_color_tokens_rgb(
    tokens: list[str],
    colors: RGBArray,
) -> None:
    """display `tokens` colored with `colors` as HTML (see `color_tokens_rgb`)"""
    display_html(color_tokens_rgb(tokens, colors, fmt="html"))
def display_color_tokens_cmap(
    tokens: list[str],
    weights: Sequence[float],
    cmap: str | matplotlib.colors.Colormap = "Blues",
) -> None:
    """display `tokens` colored by `weights` through `cmap` as HTML (see `color_tokens_cmap`)"""
    display_html(color_tokens_cmap(tokens, weights, cmap=cmap))
def display_color_maze_tokens_AOTP(
    tokens: list[str],
) -> None:
    """display AOTP-format maze tokens colored by section as HTML (see `color_maze_tokens_AOTP`)"""
    display_html(color_maze_tokens_AOTP(tokens))