Coverage for maze_dataset\plotting\plot_tokens.py: 0%

22 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-23 12:49 -0700

1"`plot_colored_text` function to plot tokens on a matplotlib axis with colored backgrounds" 

2 

3from typing import Any, Sequence 

4 

5import matplotlib.pyplot as plt 

6import numpy as np 

7 

8 

def plot_colored_text(
    tokens: Sequence[str],
    weights: Sequence[float],
    cmap: str | Any,  # treated as a colormap callable if it is not a string
    ax: plt.Axes | None = None,
    width_scale: float = 0.023,
    width_offset: float = 0.005,
    height_offset: float = 0.1,
    rect_height: float = 0.7,
    token_height: float = 0.7,
    label_height: float = 0.3,
    word_gap: float = 0.01,
    fontsize: int = 12,
    fig_height: float = 0.7,
    fig_width_scale: float = 0.25,
    char_min: int = 4,
) -> plt.Axes:
    """Plot tokens left-to-right on an axis, each on a background colored by its weight.

    Each token is drawn on a rectangle filled with ``cmap`` evaluated at the
    token's min-max normalized weight. The raw weight is printed below the
    token with two decimals. All positions are in the axis' data coordinates.

    Parameters
    ----------
    tokens : Sequence[str]
        Strings to draw, in order from left to right.
    weights : Sequence[float]
        One weight per token. The weights are min-max normalized to [0, 1] to
        pick the background color. If all weights are equal (or there is only
        one), every token gets the color at 0.0.
    cmap : str | Any
        A colormap name passed to ``plt.get_cmap``. Otherwise, a colormap-like
        callable that maps a value in [0, 1] to an RGBA tuple. Only the RGB
        part is used.
    ax : plt.Axes | None
        Axis to draw on. If ``None``, a new figure is created. Its size is
        ``(sum(max(len(tok), char_min)) * fig_width_scale, fig_height)``.
    width_scale : float
        Rectangle width per (padded) character of a token.
    width_offset : float
        Horizontal offset of both texts from the left edge of their rectangle.
    height_offset : float
        y-coordinate of the bottom of each rectangle.
    rect_height : float
        Added to ``height_offset`` to give the rectangle height (see the note
        in the body).
    token_height : float
        y-coordinate of the vertical center of the token text.
    label_height : float
        y-coordinate of the vertical center of the weight label.
    word_gap : float
        Horizontal gap between consecutive rectangles.
    fontsize : int
        Font size for both the token and its weight label.
    fig_height : float
        Figure height in inches. Only used when ``ax`` is ``None``.
    fig_width_scale : float
        Figure width in inches per (padded) character. Only used when ``ax``
        is ``None``.
    char_min : int
        Minimum number of characters a token is counted as when sizing its
        rectangle and the figure.

    Returns
    -------
    plt.Axes
        The axis that was drawn on, with its axis lines turned off.

    Raises
    ------
    AssertionError
        If ``tokens`` and ``weights`` have different lengths.
    """
    assert len(tokens) == len(weights), (
        f"The number of tokens and weights must be the same: {len(tokens)} != {len(weights)}"
    )

    # Width of each token in characters; short tokens are padded to `char_min`.
    char_widths: list[int] = [max(len(tok), char_min) for tok in tokens]

    # set up figure if needed
    if ax is None:
        total_len_estimate: float = sum(char_widths)
        _, ax = plt.subplots(
            figsize=(total_len_estimate * fig_width_scale, fig_height),
        )
    ax.axis("off")

    # Normalize the weights to be between 0 and 1.
    weight_arr: np.ndarray = np.asarray(weights, dtype=float)
    if weight_arr.size:
        w_min = weight_arr.min()
        w_range = weight_arr.max() - w_min
        # With all-equal weights the range is 0; dividing by it would give
        # 0/0 = NaN. Divide by 1 instead, so every token maps to 0.0.
        norm_weights: np.ndarray = (weight_arr - w_min) / (
            w_range if w_range > 0 else 1.0
        )
    else:
        # No tokens: np.min/np.max on an empty array would raise.
        norm_weights = weight_arr

    # Create a colormap instance
    colormap = plt.get_cmap(cmap) if isinstance(cmap, str) else cmap

    x_pos: float = 0.0
    for tok, weight, norm_wgt, char_width in zip(
        tokens, weights, norm_weights, char_widths
    ):
        color = colormap(norm_wgt)[:3]  # keep RGB, drop alpha

        # Plot the background color
        rect_width: float = width_scale * char_width
        ax.add_patch(
            plt.Rectangle(
                (x_pos, height_offset),
                rect_width,
                # NOTE(review): the height is `height_offset + rect_height`, so
                # the box spans [height_offset, 2 * height_offset + rect_height].
                # Confirm that this is intended rather than `rect_height` alone.
                # It is kept unchanged so existing plots look the same.
                height_offset + rect_height,
                fc=color,
                ec="none",
            )
        )

        # Plot the token
        ax.text(
            x_pos + width_offset,
            token_height,
            tok,
            fontsize=fontsize,
            va="center",
            ha="left",
        )

        # Plot the weight below the token
        ax.text(
            x_pos + width_offset,
            label_height,
            f"{weight:.2f}",
            fontsize=fontsize,
            va="center",
            ha="left",
        )

        x_pos += rect_width + word_gap

    return ax