Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- coding: utf-8 -*- 

2""" 

3Authors: Josef Perktold, Skipper Seabold, Denis A. Engemann 

4""" 

5from statsmodels.compat.python import iterkeys, lrange, iteritems 

6import numpy as np 

7 

8from statsmodels.graphics.plottools import rainbow 

9import statsmodels.graphics.utils as utils 

10 

11 

def interaction_plot(x, trace, response, func=np.mean, ax=None, plottype='b',
                     xlabel=None, ylabel=None, colors=None, markers=None,
                     linestyles=None, legendloc='best', legendtitle=None,
                     **kwargs):
    """
    Interaction plot for factor level statistics.

    The response is grouped by every (`trace`, `x`) combination using a
    `pandas.DataFrame`, reduced with `func`, and one line and/or set of
    markers is drawn per `trace` level.

    Note. If `x` holds string levels they are internally recoded to
    integers (in sorted order of the unique levels) and the original labels
    are restored as x tick labels. This ensures matplotlib compatibility.

    Parameters
    ----------
    x : array_like
        The `x` factor levels constitute the x-axis. If a `pandas.Series` is
        given its name will be used in `xlabel` if `xlabel` is None.
    trace : array_like
        The `trace` factor levels will be drawn as lines in the plot.
        If `trace` is a `pandas.Series` its name will be used as the
        `legendtitle` if `legendtitle` is None.
    response : array_like
        The response or dependent variable. If a `pandas.Series` is given
        its name will be used in the y-axis label if `ylabel` is None.
    func : function or str
        Anything accepted by `pandas.DataFrame.aggregate`. This is applied to
        the response variable grouped by the trace and x levels. Its
        ``__name__`` (or ``str(func)`` when it has none, e.g. for the string
        ``'mean'``) is used in the y-axis label.
    ax : axes, optional
        Matplotlib axes instance
    plottype : str {'line', 'scatter', 'both'}, optional
        The type of plot to return. Can be 'l', 's', or 'b'. Default is
        'b'.
    xlabel : str, optional
        Label to use for `x`. Default is 'X'. If `x` is a named
        `pandas.Series` its name is used.
    ylabel : str, optional
        Name used for the response in the y-axis label; the label itself is
        always rendered as ``'<func name> of <ylabel>'``. Default is
        'response', or the name of `response` if it is a named
        `pandas.Series`.
    colors : list, optional
        If given, must have length == number of levels in trace.
    markers : list, optional
        If given, must have length == number of levels in trace
    linestyles : list, optional
        If given, must have length == number of levels in trace.
    legendloc : {None, str, int}
        Location passed to the legend command.
    legendtitle : {None, str}
        Title of the legend. Default is 'Trace', or the name of `trace` if
        it is a named `pandas.Series`.
    **kwargs
        These will be passed to the plot command used either plot or scatter.
        If you want to control the overall plotting options, use kwargs.

    Returns
    -------
    Figure
        The figure given by `ax.figure` or a new instance.

    Raises
    ------
    ValueError
        If `colors`, `markers` or `linestyles` do not have one entry per
        trace level, or if `plottype` is not understood.

    Examples
    --------
    >>> import numpy as np
    >>> np.random.seed(12345)
    >>> weight = np.random.randint(1,4,size=60)
    >>> duration = np.random.randint(1,3,size=60)
    >>> days = np.log(np.random.randint(1,30, size=60))
    >>> fig = interaction_plot(weight, duration, days,
    ...             colors=['red','blue'], markers=['D','^'], ms=10)
    >>> import matplotlib.pyplot as plt
    >>> plt.show()

    .. plot::

       import numpy as np
       from statsmodels.graphics.factorplots import interaction_plot
       np.random.seed(12345)
       weight = np.random.randint(1,4,size=60)
       duration = np.random.randint(1,3,size=60)
       days = np.log(np.random.randint(1,30, size=60))
       fig = interaction_plot(weight, duration, days,
                   colors=['red','blue'], markers=['D','^'], ms=10)
       import matplotlib.pyplot as plt
       #plt.show()
    """

    from pandas import DataFrame

    def _name_of(obj, default):
        # An unnamed Series exposes ``name=None``; use the default in that
        # case instead of leaking ``None`` into the labels.
        name = getattr(obj, 'name', None)
        return default if name is None else name

    fig, ax = utils.create_mpl_ax(ax)

    response_name = ylabel or _name_of(response, 'response')
    # ``func`` may be a string such as 'mean', which has no ``__name__``.
    func_name = getattr(func, '__name__', str(func))
    ylabel = '%s of %s' % (func_name, response_name)
    xlabel = xlabel or _name_of(x, 'X')
    legendtitle = legendtitle or _name_of(trace, 'Trace')

    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)

    x_values = x_levels = None
    # Inspect the first element positionally: ``x[0]`` on a Series is a
    # label lookup and fails when the index does not contain 0.
    if isinstance(np.asarray(x)[0], str):
        if not hasattr(x, 'dtype'):
            # plain sequences (e.g. lists) need array semantics for recoding
            x = np.asarray(x)
        x_levels = list(np.unique(x))
        x_values = list(range(len(x_levels)))
        x = _recode(x, dict(zip(x_levels, x_values)))

    data = DataFrame(dict(x=x, trace=trace, response=response))
    plot_data = data.groupby(['trace', 'x']).aggregate(func).reset_index()

    # check plot args
    n_trace = len(plot_data['trace'].unique())

    linestyles = ['-'] * n_trace if linestyles is None else linestyles
    markers = ['.'] * n_trace if markers is None else markers
    colors = rainbow(n_trace) if colors is None else colors

    if len(linestyles) != n_trace:
        raise ValueError("Must be a linestyle for each trace level")
    if len(markers) != n_trace:
        raise ValueError("Must be a marker for each trace level")
    if len(colors) != n_trace:
        raise ValueError("Must be a color for each trace level")

    # Normalise the accepted spellings before drawing anything.
    if plottype in ('both', 'b'):
        kind = 'both'
    elif plottype in ('line', 'l'):
        kind = 'line'
    elif plottype in ('scatter', 's'):
        kind = 'scatter'
    else:
        raise ValueError("Plot type %s not understood" % plottype)

    for i, (_, group) in enumerate(plot_data.groupby('trace')):
        # trace label
        label = str(group['trace'].values[0])
        if kind == 'both':
            ax.plot(group['x'], group['response'], color=colors[i],
                    marker=markers[i], label=label,
                    linestyle=linestyles[i], **kwargs)
        elif kind == 'line':
            ax.plot(group['x'], group['response'], color=colors[i],
                    label=label, linestyle=linestyles[i], **kwargs)
        else:
            ax.scatter(group['x'], group['response'], color=colors[i],
                       label=label, marker=markers[i], **kwargs)

    ax.legend(loc=legendloc, title=legendtitle)
    ax.margins(.1)

    # Restore the original string labels if x was recoded to integers.
    if x_levels is not None:
        ax.set_xticks(x_values)
        ax.set_xticklabels(x_levels)
    return fig

158 

159 

def _recode(x, levels):
    """ Recode categorical data to int factor.

    Parameters
    ----------
    x : array_like
        Categorically coded data of str (or object) type. A `pandas.Series`
        is accepted as well; anything else is converted with `numpy.asarray`.
    levels : dict
        mapping of labels to integer-codings. Its keys must be exactly the
        unique values found in `x`.

    Returns
    -------
    out : numpy.ndarray or pandas.Series
        Integer codes with the same length as `x`. If `x` is a
        `pandas.Series` the result is a Series carrying the same name and
        index, otherwise a `numpy.ndarray`.

    Raises
    ------
    ValueError
        If `x` is not of str/object dtype, if `levels` is not a dict, or if
        the keys of `levels` do not match the unique values of `x`.
    """
    from pandas import Series

    # Remember Series metadata explicitly; testing the truthiness of the
    # name would drop the index for unnamed (or falsy-named) Series.
    is_series = isinstance(x, Series)
    name = index = None
    if is_series:
        name = x.name
        index = x.index
        x = x.values
    x = np.asarray(x)

    if x.dtype.type not in [np.str_, np.object_]:
        raise ValueError('This is not a categorial factor.'
                         ' Array of str type required.')

    if not isinstance(levels, dict):
        raise ValueError('This is not a valid value for levels.'
                         ' Dict required.')

    # array_equal also copes with differing numbers of levels, where an
    # elementwise ``==`` cannot be broadcast.
    if not np.array_equal(np.unique(x), np.unique(list(levels))):
        raise ValueError('The levels do not match the array values.')

    # ``np.int`` was removed from NumPy; the builtin ``int`` is equivalent.
    out = np.empty(x.shape[0], dtype=int)
    for level, coding in levels.items():
        out[x == level] = coding

    if is_series:
        out = Series(out, name=name, index=index)

    return out