Coverage for C:\src\imod-python\imod\visualize\waterbalance.py: 95%

61 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-08 14:15 +0200

1import itertools 

2 

3import matplotlib.pyplot as plt 

4import numpy as np 

5import pandas as pd 

6 

7 

8def _draw_bars(ax, x, df, labels, barwidth, colors): 

9 ndates, _ = df.shape 

10 bottoms = np.hstack([np.zeros((ndates, 1)), df.cumsum(axis=1).values]).T[:-1] 

11 heights = df.values.T 

12 if colors is None: 

13 for label, bottom, height in zip(labels, bottoms, heights): 

14 ax.bar( 

15 x, 

16 bottom=bottom, 

17 height=height, 

18 width=barwidth, 

19 edgecolor="k", 

20 label=label, 

21 ) 

22 else: 

23 for label, bottom, height, color in zip(labels, bottoms, heights, colors): 

24 ax.bar( 

25 x, 

26 bottom=bottom, 

27 height=height, 

28 width=barwidth, 

29 edgecolor="k", 

30 label=label, 

31 color=color, 

32 ) 

33 

34 

def waterbalance_barchart(
    df,
    inflows,
    outflows,
    datecolumn=None,
    format="%Y-%m-%d",
    ax=None,
    unit=None,
    colors=None,
):
    """
    Draw a water balance as stacked bars, one in/out pair per row of ``df``.

    For every row of ``df`` two stacked bars are drawn side by side: the
    left one ("in") stacks the ``inflows`` columns, the right one ("out")
    stacks the ``outflows`` columns. The row's date is written below each
    pair, and a legend is placed to the right of the axes.

    Parameters
    ----------
    df : pandas.DataFrame
        The dataframe containing the water balance data.
    inflows : listlike of str
        Columns of ``df`` stacked in the "in" bar.
    outflows : listlike of str
        Columns of ``df`` stacked in the "out" bar.
    datecolumn : str, optional
        Column of ``df`` holding the dates used as labels. If None
        (default), the index of ``df`` is used.
    format : str, optional
        ``strftime`` format for the date labels. Labels that have no
        ``strftime`` method are rendered with ``str`` instead.
    ax : matplotlib.Axes, optional
        Axes to draw on. If None, ``plt.gca()`` is used.
    unit : str, optional
        Text of the y-axis label.
    colors : listlike of strings or tuples, or dict, optional
        A list/tuple is consumed in order: first the inflows, then the
        outflows. A dict maps every flow name to its color. If None, no
        colors are passed to matplotlib.

    Returns
    -------
    ax : matplotlib.Axes

    Raises
    ------
    TypeError
        If ``df`` is not a pandas.DataFrame.
    ValueError
        If ``datecolumn`` or a flow is not in ``df``, or if fewer colors
        than flows are given.

    Examples
    --------

    >>> fig, ax = plt.subplots()
    >>> imod.visualize.waterbalance_barchart(
    ...     ax=ax,
    ...     df=df,
    ...     inflows=["Rainfall", "River upstream"],
    ...     outflows=["Evapotranspiration", "Discharge to Sea"],
    ...     datecolumn="Time",
    ...     format="%Y-%m-%d",
    ...     unit="m3/d",
    ...     colors=["#ca0020", "#f4a582", "#92c5de", "#0571b0"],
    ... )
    >>> fig.savefig("Waterbalance.png", dpi=300, bbox_inches="tight")

    """
    # Do some checks
    if not isinstance(df, pd.DataFrame):
        raise TypeError("df should be a pandas.DataFrame")
    # Normalize to lists: ``df[some_tuple]`` looks up a single (tuple) key,
    # and ``inflows + outflows`` fails for mixed list/tuple input.
    inflows = list(inflows)
    outflows = list(outflows)
    if datecolumn is not None and datecolumn not in df.columns:
        raise ValueError(f"datecolumn {datecolumn} not in df")
    for column in itertools.chain(inflows, outflows):
        if column not in df:
            raise ValueError(f"{column} not in df")
    if colors is not None:
        ncolors = len(colors)
        nflows = len(inflows) + len(outflows)
        if ncolors < nflows:
            raise ValueError(
                f"Not enough colors: Number of flows is {nflows}, while number of colors is {ncolors}"
            )
    # Deal with colors, takes both dict and list
    if isinstance(colors, dict):
        incolors = [colors[k] for k in inflows]
        outcolors = [colors[k] for k in outflows]
    elif isinstance(colors, (tuple, list)):
        ninflows = len(inflows)
        incolors = colors[:ninflows]
        outcolors = colors[ninflows:]
    else:
        incolors = None
        outcolors = None

    # Determine x positions: every row gets a slot of three bar widths,
    # holding the "in" bar, the "out" bar, and a one-bar-wide gap.
    ndates = df.shape[0]
    barwidth = 1.0
    r1 = np.arange(ndates) * (barwidth * 3)
    r2 = r1 + barwidth
    r_between = 0.5 * (r1 + r2)

    # Grab ax if not provided directly
    if ax is None:
        ax = plt.gca()

    # Draw inflows
    _draw_bars(
        ax=ax, x=r1, df=df[inflows], labels=inflows, barwidth=barwidth, colors=incolors
    )
    # Draw outflows
    _draw_bars(
        ax=ax,
        x=r2,
        df=df[outflows],
        labels=outflows,
        barwidth=barwidth,
        colors=outcolors,
    )

    # Place xticks: "in", date, "out" for every row
    xticks_location = list(itertools.chain(*zip(r1, r_between, r2)))
    dates = df.index if datecolumn is None else df[datecolumn]
    xticks_labels = []
    for date in dates:
        # Non-datetime labels (e.g. strings or integers) have no strftime.
        label = date.strftime(format) if hasattr(date, "strftime") else str(date)
        # Place the date labels two lines (two \n) below the minor labels ("in", "out")
        xticks_labels.extend(["in", f"\n\n{label}", "out"])

    # Hide the major tick marks, keep their labels. Lengthen the minor ticks
    # (drawn in the gaps between pairs) so they extend down to the dates.
    ax.tick_params(axis="x", which="major", bottom=False, top=False, labelbottom=True)
    ax.tick_params(
        axis="x",
        which="minor",
        bottom=True,
        top=False,
        labelbottom=False,
        length=barwidth * 45,
    )
    ax.xaxis.set_ticks(xticks_location)
    ax.xaxis.set_ticklabels(xticks_labels)
    # Minor ticks sit in the middle of the gap preceding every pair but the first.
    xticks_location_minor = r1[1:] - barwidth
    ax.xaxis.set_ticks(xticks_location_minor, minor=True)

    # Create a legend on the right side of the chart
    ax.legend(
        loc="upper left",
        bbox_to_anchor=(1.03, 1.0),
        ncol=2,
        borderaxespad=0,
        frameon=True,
    )

    # Set a unit on the y-axis. Note: Axis.set_label only sets the artist's
    # legend label; the visible axis label is set via set_ylabel.
    if unit is not None:
        ax.set_ylabel(unit)

    return ax