Coverage for maze_dataset\plotting\print_tokens.py: 36%

77 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-23 12:49 -0700

1"""Functions to print tokens with colors in different formats 

2 

3you can color the tokens by their: 

4 

5- type (i.e. adjacency list, origin, target, path) using `color_maze_tokens_AOTP` 

6- custom weights (i.e. attention weights) using `color_tokens_cmap` 

7- entirely custom colors using `color_tokens_rgb` 

8 

9and the output can be in different formats, specified by `FormatType` (html, latex, terminal) 

10 

11""" 

12 

13import html 

14import textwrap 

15from typing import Literal, Sequence 

16 

17import matplotlib 

18import numpy as np 

19from IPython.display import HTML, display 

20from jaxtyping import UInt8 

21from muutils.misc import flatten 

22 

23from maze_dataset.constants import SPECIAL_TOKENS 

24from maze_dataset.token_utils import tokens_between 

25 

# jaxtyping annotation: the shape string "n 3" names two axes (`n` rows of
# 3 channels), and `UInt8` restricts the dtype to unsigned 8-bit integers
RGBArray = UInt8[np.ndarray, "n 3"]
"array of RGB values with shape `(n, 3)`: one uint8 `(r, g, b)` row per color"

28 

FormatType = Literal["html", "latex", "terminal", None]
"output format for the tokens"

# every template is filled via `str.format` with two fields:
#   {clr} -- the color channels, already joined with the format's separator
#   {tok} -- the (escaped) token text
TEMPLATES: dict[FormatType, str] = {
    # black text on an `rgb(...)` background, padded with non-breaking spaces
    "html": (
        '<span style="color: black; background-color: rgb({clr})">'
        "&nbsp{tok}&nbsp"
        "</span>"
    ),
    # doubled braces are literal braces once `str.format` has run
    "latex": (
        "\\colorbox[RGB]{{ {clr} }}"
        "{{ \\texttt{{ {tok} }} }}"
    ),
    # ANSI escapes: `30` = black foreground, `48;2;r;g;b` = 24-bit background,
    # `0` = reset all attributes
    "terminal": (
        "\033[30m"
        "\033[48;2;{clr}m"
        "{tok}"
        "\033[0m"
    ),
}
"templates of printing tokens in different formats"

# separator placed between the r, g, b channels when building `{clr}`
_COLOR_JOIN: dict[FormatType, str] = {
    "html": ",",
    "latex": ",",
    "terminal": ";",
}
"joiner for colors in different formats"

45 

46 

# characters that are special in LaTeX text mode, mapped to their escaped form.
# used with `str.translate`, which rewrites the string in a single pass, so the
# backslashes and braces introduced by one replacement are never escaped again
_LATEX_ESCAPES: dict[int, str] = str.maketrans(
    {
        "\\": "\\textbackslash{}",
        "{": "\\{",
        "}": "\\}",
        "_": "\\_",
        "#": "\\#",
        "$": "\\$",
        "%": "\\%",
        "&": "\\&",
        "~": "\\textasciitilde{}",
        "^": "\\textasciicircum{}",
    }
)


def _escape_tok(
    tok: str,
    fmt: FormatType,
) -> str:
    """escape token based on format

    # Parameters:
    - `tok: str`: raw token text
    - `fmt: FormatType`: target output format

    # Returns:
    - `str`: the token, safe to embed in output of the given format:
        - `"html"`: escaped with `html.escape`
        - `"latex"`: LaTeX special characters (`\\ { } _ # $ % & ~ ^`) escaped
        - `"terminal"` or `None`: returned unchanged

    # Raises:
    - `ValueError`: if `fmt` is not one of the formats listed above
    """
    # `None` means "caller supplies a custom template", so there is no
    # format-specific escaping to apply
    if fmt is None or fmt == "terminal":
        return tok
    if fmt == "html":
        return html.escape(tok)
    if fmt == "latex":
        return tok.translate(_LATEX_ESCAPES)
    raise ValueError(f"Unexpected format: {fmt}")

60 

61 

def color_tokens_rgb(
    tokens: list,
    colors: Sequence[Sequence[int]],
    fmt: FormatType = "html",
    template: str | None = None,
    clr_join: str | None = None,
    max_length: int | None = None,
) -> str:
    """color tokens from a list with an RGB color array

    each token is substituted into a format template together with its color,
    and the rendered pieces are joined with single spaces. tokens and colors
    are paired positionally with `zip`, so surplus entries of the longer
    sequence are ignored.

    tokens will not be escaped if `fmt` is None

    # Parameters:
    - `tokens: list`: token strings to render
    - `colors: Sequence[Sequence[int]]`: one `(r, g, b)` triple per token. every channel is passed through `int()` before formatting
    - `fmt: FormatType`: built-in output format. when not None, the template and color separator are looked up in `TEMPLATES` / `_COLOR_JOIN`, and `template` / `clr_join` must be left as None
    - `template: str | None`: custom format string with `{clr}` and `{tok}` fields. required when `fmt` is None
    - `clr_join: str | None`: separator placed between the color channels. required when `fmt` is None
    - `max_length: int | None`: Max number of characters before triggering a line wrap, i.e., making a new colorbox. If `None`, no limit on max length. Words are never split, so a single long word may still exceed the limit, and a token that wraps to nothing (empty or whitespace-only) is dropped.

    # Returns:
    - `str`: the formatted tokens, separated by single spaces

    # Raises:
    - `AssertionError`: if `template` / `clr_join` are given together with a `fmt`, or are missing when `fmt` is None
    """
    # resolve the template and the color-channel separator
    if fmt is None:
        assert template is not None, "`template` is required when `fmt` is None"
        assert clr_join is not None, "`clr_join` is required when `fmt` is None"
    else:
        assert template is None, "`template` can only be used when `fmt` is None"
        assert clr_join is None, "`clr_join` can only be used when `fmt` is None"
        template = TEMPLATES[fmt]
        clr_join = _COLOR_JOIN[fmt]

    if max_length is not None:
        # split every token into wrapped pieces; each piece becomes its own
        # colorbox and inherits the color of the token it came from
        wrapped_tokens: list = []
        wrapped_colors: list = []
        for tok, clr in zip(tokens, colors):
            pieces = textwrap.wrap(
                tok,
                width=max_length,
                break_long_words=False,
                break_on_hyphens=False,
            )
            wrapped_tokens.extend(pieces)
            wrapped_colors.extend([clr] * len(pieces))
        tokens, colors = wrapped_tokens, wrapped_colors

    # put everything together
    return " ".join(
        template.format(
            clr=clr_join.join(str(int(channel)) for channel in clr),
            # a custom template (`fmt is None`) has no known escaping rules,
            # so the token is inserted verbatim
            tok=tok if fmt is None else _escape_tok(tok, fmt),
        )
        for tok, clr in zip(tokens, colors)
    )

115 

116 

def color_tokens_cmap(
    tokens: list[str],
    weights: Sequence[float],
    cmap: str | matplotlib.colors.Colormap = "Blues",
    fmt: FormatType = "html",
    template: str | None = None,
    labels: bool = False,
    clr_join: str | None = None,
) -> str:
    """color tokens given a list of weights and a colormap

    # Parameters:
    - `tokens: list[str]`: tokens to render
    - `weights: Sequence[float]`: one weight per token; must have the same length as `tokens`
    - `cmap: str | matplotlib.colors.Colormap`: colormap, or the name of one to look up in `matplotlib.colormaps`
    - `fmt: FormatType`: output format, forwarded to `color_tokens_rgb`
    - `template: str | None`: forwarded unchanged to `color_tokens_rgb`
    - `labels: bool`: if True, append a second line with each weight printed underneath its token. only supported for `fmt="terminal"`
    - `clr_join: str | None`: forwarded unchanged to `color_tokens_rgb` (which requires it alongside `template` when `fmt` is None)

    # Returns:
    - `str`: the colored tokens, plus the label line if `labels` is set

    # Raises:
    - `AssertionError`: if `tokens` and `weights` differ in length
    - `NotImplementedError`: if `labels` is requested for a format other than `"terminal"`
    """
    assert len(tokens) == len(weights), f"{len(tokens)} != {len(weights)}"
    weights = np.array(weights)
    # normalize weights to [0, 1]
    weights_norm = matplotlib.colors.Normalize()(weights)

    if isinstance(cmap, str):
        cmap = matplotlib.colormaps.get_cmap(cmap)

    # keep the first three channels of the colormap output and scale them to
    # 0..255; `color_tokens_rgb` converts each channel with `int()`
    colors: RGBArray = cmap(weights_norm)[:, :3] * 255

    output: str = color_tokens_rgb(
        tokens=tokens,
        colors=colors,
        fmt=fmt,
        template=template,
        clr_join=clr_join,
    )

    if labels:
        if fmt != "terminal":
            raise NotImplementedError("labels only supported for terminal")
        # align labels with the tokens
        label_cells: list[str] = []
        for tok, weight in zip(tokens, weights):
            # 1 decimal place
            weight_str: str = f"{weight:.1f}"
            # omit if longer than token, since it would break the alignment
            if len(weight_str) > len(tok):
                weight_str = ""
            # left-aligned and padded with trailing spaces to the token length
            label_cells.append(weight_str.ljust(len(tok)))
        # every cell is followed by one space, mirroring the single space
        # between tokens on the line above
        output += "\n" + "".join(f"{cell} " for cell in label_cells)

    return output

159 

160 

# one background color per maze section, keyed by that section's
# (start token, end token) pair; the colors were picked to roughly resemble the
# visual representation of the maze. insertion order is the order in which
# `color_maze_tokens_AOTP` emits the sections.
_MAZE_TOKENS_DEFAULT_COLORS: dict[tuple[str, str], tuple[int, int, int]] = {
    # adjacency list -- purple
    (SPECIAL_TOKENS.ADJLIST_START, SPECIAL_TOKENS.ADJLIST_END): (217, 210, 233),
    # origin -- green
    (SPECIAL_TOKENS.ORIGIN_START, SPECIAL_TOKENS.ORIGIN_END): (217, 234, 211),
    # target -- red
    (SPECIAL_TOKENS.TARGET_START, SPECIAL_TOKENS.TARGET_END): (234, 209, 220),
    # path -- blue
    (SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END): (207, 226, 243),
}
"default colors for maze tokens, roughly matches the format of `as_pixels`"

173 

174 

def color_maze_tokens_AOTP(
    tokens: list[str], fmt: FormatType = "html", template: str | None = None, **kwargs
) -> str:
    """color tokens assuming AOTP format

    i.e: adjacency list, origin, target, path

    for every `(start, end)` token pair in `_MAZE_TOKENS_DEFAULT_COLORS`, the
    result of `tokens_between` (called with the start and end tokens included)
    is joined with spaces into a single chunk, and that chunk is rendered in
    the color mapped to the pair.

    # Parameters:
    - `tokens: list[str]`: tokens of a maze
    - `fmt: FormatType`: output format, forwarded to `color_tokens_rgb`
    - `template: str | None`: forwarded to `color_tokens_rgb`
    - `**kwargs`: any further keyword arguments for `color_tokens_rgb`

    # Returns:
    - `str`: whatever `color_tokens_rgb` returns for the section chunks
    """
    sections: list[str] = []
    palette: list[tuple[int, int, int]] = []
    for (start_tok, end_tok), rgb in _MAZE_TOKENS_DEFAULT_COLORS.items():
        section_tokens = tokens_between(
            tokens, start_tok, end_tok, include_start=True, include_end=True
        )
        sections.append(" ".join(section_tokens))
        palette.append(rgb)

    colors: RGBArray = np.array(palette, dtype=np.uint8)

    return color_tokens_rgb(
        tokens=sections, colors=colors, fmt=fmt, template=template, **kwargs
    )

199 

200 

def display_html(html: str):
    "wrap the given markup in an IPython `HTML` object and pass it to `display`"
    rendered = HTML(html)
    display(rendered)

203 

204 

def display_color_tokens_rgb(
    tokens: list[str],
    colors: RGBArray,
) -> None:
    "format `tokens` with `color_tokens_rgb` as html and show the result via `display_html`"
    display_html(color_tokens_rgb(tokens, colors, fmt="html"))

211 

212 

def display_color_tokens_cmap(
    tokens: list[str],
    weights: Sequence[float],
    cmap: str | matplotlib.colors.Colormap = "Blues",
) -> None:
    "format `tokens` with `color_tokens_cmap` (default format) and show the result via `display_html`"
    display_html(color_tokens_cmap(tokens, weights, cmap))

220 

221 

def display_color_maze_tokens_AOTP(
    tokens: list[str],
) -> None:
    "format `tokens` with `color_maze_tokens_AOTP` (default format) and show the result via `display_html`"
    display_html(color_maze_tokens_AOTP(tokens))