Coverage for maze_dataset\plotting\plot_tokens.py: 0%
22 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
1"`plot_colored_text` function to plot tokens on a matplotlib axis with colored backgrounds"
3from typing import Any, Sequence
5import matplotlib.pyplot as plt
6import numpy as np
9def plot_colored_text(
10 tokens: Sequence[str],
11 weights: Sequence[float],
12 cmap: str | Any, # assume its a colormap if not a string
13 ax: plt.Axes = None,
14 width_scale: float = 0.023,
15 width_offset: float = 0.005,
16 height_offset: float = 0.1,
17 rect_height: float = 0.7,
18 token_height: float = 0.7,
19 label_height: float = 0.3,
20 word_gap: float = 0.01,
21 fontsize: int = 12,
22 fig_height: float = 0.7,
23 fig_width_scale: float = 0.25,
24 char_min: int = 4,
25):
26 "hacky function to plot tokens on a matplotlib axis with colored backgrounds"
27 assert len(tokens) == len(weights), (
28 f"The number of tokens and weights must be the same: {len(tokens)} != {len(weights)}"
29 )
30 total_len_estimate: float = sum([max(len(tok), char_min) for tok in tokens])
31 # set up figure if needed
32 if ax is None:
33 fig, ax = plt.subplots(
34 figsize=(total_len_estimate * fig_width_scale, fig_height)
35 )
36 ax.axis("off")
38 # Normalize the weights to be between 0 and 1
39 norm_weights: Sequence[float] = (weights - np.min(weights)) / (
40 np.max(weights) - np.min(weights)
41 )
43 # Create a colormap instance
44 if isinstance(cmap, str):
45 colormap = plt.get_cmap(cmap)
46 else:
47 colormap = cmap
49 x_pos: float = 0.0
50 for i, (tok, weight, norm_wgt) in enumerate(zip(tokens, weights, norm_weights)):
51 color = colormap(norm_wgt)[:3]
53 # Plot the background color
54 rect_width = width_scale * max(len(tok), char_min)
55 ax.add_patch(
56 plt.Rectangle(
57 (x_pos, height_offset),
58 rect_width,
59 height_offset + rect_height,
60 fc=color,
61 ec="none",
62 )
63 )
65 # Plot the token
66 ax.text(
67 x_pos + width_offset,
68 token_height,
69 tok,
70 fontsize=fontsize,
71 va="center",
72 ha="left",
73 )
75 # Plot the weight below the token
76 ax.text(
77 x_pos + width_offset,
78 label_height,
79 f"{weight:.2f}",
80 fontsize=fontsize,
81 va="center",
82 ha="left",
83 )
85 x_pos += rect_width + word_gap
87 return ax