Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Helper functions for graphics with Matplotlib.""" 

2from statsmodels.compat.python import lrange 

3 

4__all__ = ['create_mpl_ax', 'create_mpl_fig'] 

5 

6 

7def _import_mpl(): 

8 """This function is not needed outside this utils module.""" 

9 try: 

10 import matplotlib.pyplot as plt 

11 except: 

12 raise ImportError("Matplotlib is not found.") 

13 

14 return plt 

15 

16 

def create_mpl_ax(ax=None):
    """Return a (figure, axis) pair, creating them only when necessary.

    Use this when a plotting function needs a single axis to draw on.

    Parameters
    ----------
    ax : AxesSubplot, optional
        Axis to draw on. When omitted (None), a new figure is created and a
        single subplot is added to it.

    Returns
    -------
    fig : Figure
        The newly created figure when `ax` is None; otherwise ``ax.figure``.
    ax : AxesSubplot
        The newly created axis when `ax` was None; otherwise the same `ax`
        object that was passed in.

    Notes
    -----
    Only the figure-creation path imports `matplotlib.pyplot`. pyplot should
    be used solely to create figures with ``plt.figure``; everything else it
    exposes can and should be imported directly from the relevant Matplotlib
    module.

    See Also
    --------
    create_mpl_fig

    Examples
    --------
    A plotting function has a keyword ``ax=None``. Then calls:

    >>> from statsmodels.graphics import utils
    >>> fig, ax = utils.create_mpl_ax(ax)
    """
    # Caller supplied an axis: reuse it together with its parent figure.
    if ax is not None:
        return ax.figure, ax

    # No axis given: build a fresh figure holding one subplot.
    pyplot = _import_mpl()
    new_fig = pyplot.figure()
    new_ax = new_fig.add_subplot(111)
    return new_fig, new_ax

61 

62 

def create_mpl_fig(fig=None, figsize=None):
    """Return a figure, creating a new one only if none is supplied.

    Intended for plots that need several axes. The caller is expected to add
    those axes itself with ``fig.add_subplot()``.

    Parameters
    ----------
    fig : Figure, optional
        Existing figure. When given, it is returned unchanged.
    figsize : tuple, optional
        Passed as ``figsize`` to ``plt.figure`` when a new figure is created.
        Ignored when `fig` is given.

    Returns
    -------
    Figure
        The supplied `fig`, or a newly created figure when `fig` is None.

    See Also
    --------
    create_mpl_ax
    """
    if fig is not None:
        return fig

    pyplot = _import_mpl()
    return pyplot.figure(figsize=figsize)

90 

91 

def maybe_name_or_idx(idx, model):
    """
    Give a name or an integer and return the name and integer location of the
    column in a design matrix.

    Parameters
    ----------
    idx : int, str, list, tuple or None
        Column position (any integer type, including NumPy integers), column
        name, or a list/tuple of these, resolved element by element. None
        selects every column of ``model.exog``.
    model : object
        Must expose ``exog`` (with a ``shape`` attribute) and ``exog_names``
        (a list of column names supporting indexing and ``.index``).

    Returns
    -------
    exog_name : str or list
        Column name(s) corresponding to `idx`.
    exog_idx : int or list
        Column position(s) corresponding to `idx`.

    Raises
    ------
    ValueError
        If a string `idx` is not found in ``model.exog_names``.
    """
    # Local import keeps the dependency next to its only use.
    # np.integer subclasses are registered as numbers.Integral.
    import numbers

    if idx is None:
        idx = list(range(model.exog.shape[1]))
    if isinstance(idx, numbers.Integral):
        exog_name = model.exog_names[idx]
        exog_idx = idx
    elif isinstance(idx, (tuple, list)):
        # anticipate index as list and recurse
        exog_name = []
        exog_idx = []
        for item in idx:
            exog_name_item, exog_idx_item = maybe_name_or_idx(item, model)
            exog_name.append(exog_name_item)
            exog_idx.append(exog_idx_item)
    else:  # assume we've got a string variable
        exog_name = idx
        exog_idx = model.exog_names.index(idx)

    return exog_name, exog_idx

115 

116 

def get_data_names(series_or_dataframe):
    """
    Return variable name(s) for array-like or pandas-like data.

    Parameters
    ----------
    series_or_dataframe : array_like
        Series-like data with a ``name`` attribute, DataFrame-like data with
        a ``columns`` attribute, or a plain array/sequence.

    Returns
    -------
    str or list
        For 2-d data with ``columns``: the column labels as a list.
        For 1-d data with a non-None ``name``: that name, returned as is.
        Otherwise generated names ``"X0"``, ``"X1"``, ...: a single string
        when there is one variable, a list when there are several.
    """
    data = series_or_dataframe
    # Objects without a ``shape`` (e.g. plain lists) are treated as 1-d.
    shape = tuple(getattr(data, "shape", (1,)))

    if len(shape) > 1:
        # Check ``columns`` only for 2-d data. Looking up ``name`` on a
        # DataFrame can return a column Series called "name". Testing an
        # Index for truthiness raises ValueError.
        columns = getattr(data, "columns", None)
        if columns is not None:
            return list(columns)
        nvars = shape[1]
    else:
        name = getattr(data, "name", None)
        if name is not None:
            return name
        nvars = 1

    names = ["X%d" % i for i in range(nvars)]
    return names[0] if nvars == 1 else names

134 

135 

def annotate_axes(index, labels, points, offset_points, size, ax, **kwargs):
    """
    Add text annotations to `ax` for each position listed in `index`.

    For every ``j`` in `index`, ``labels[j]`` is placed at ``points[j]``.
    The text is shifted by ``offset_points[j]``, measured in offset points.
    The text uses font `size`, and any extra keyword arguments are forwarded
    to ``ax.annotate``.

    Returns
    -------
    ax
        The same axes object that was passed in.
    """
    for j in index:
        ax.annotate(labels[j], points[j], xytext=offset_points[j],
                    textcoords="offset points", size=size, **kwargs)
    return ax