Package trunk :: Package BIP :: Package Bayes :: Module PlotMeld
[hide private]

Source Code for Module trunk.BIP.Bayes.PlotMeld

# To change this template, choose Tools | Templates
# and open the template in the editor.
"""
Module with specialized plotting functions for the Melding results
"""
  6   
  7  __author__="fccoelho" 
  8  __date__ ="$06/01/2010 11:24:20$" 
  9  __docformat__ = "restructuredtext en" 
 10   
 11  from itertools import cycle 
 12  from scipy import stats 
 13  import glob 
 14  import datetime 
 15  from numpy import * 
 16  import matplotlib.pyplot as P 
 17  from matplotlib.dates import date2num 
 18  from scipy.stats import gaussian_kde 
 19  import pdb 
 20   
 21   
 22   
def plot_series(tim, obs, series, names=None, title='series'):
    """
    Plot percentile curves of the simulated series and save them as a PNG.

    For every selected field the 5th and 95th percentiles are drawn with
    dash-dot lines and the median with a thicker solid line, all in the same
    colour. Observations are overlaid with ``'*'`` markers when available.

    :Parameters:
        - `tim`: x values handed to ``P.plot``.
        - `obs`: container of observed data; ``obs[n]`` is plotted for every
          selected name `n` for which ``n in obs``.
        - `series`: sequence of structured arrays. For each selected field
          the arrays are concatenated along axis 1.
        - `names`: field names to plot. Empty or None selects every field in
          ``series[0].dtype.names``.
        - `title`: base name of the output file, ``title + '.png'``.
    """
    if not names:
        names = series[0].dtype.names
    colours = cycle(['b', 'g', 'r', 'c', 'm', 'y', 'k'])
    for n in names:
        # next() works on Python 2.6+ and 3; the .next() method is Python 2 only.
        co = next(colours)
        # Iterating the transpose yields one sample vector per x value.
        samples = concatenate([s[n] for s in series], axis=1).T
        curves = (
            (5, co + '-.', {}),
            (95, co + '-.', {}),
            (50, co + '-', {'lw': 2, 'label': n}),
        )
        for pct, style, extra in curves:
            level = [stats.scoreatpercentile(t, pct) for t in samples]
            P.plot(tim, level, style, **extra)
        if n in obs:
            P.plot(tim, obs[n], '*', label=n + ' obs.')
    P.savefig(title + '.png')
41
def plot_pred(tim, series, y, fig, names=None, title='series'):
    """
    Draw percentile bands of predicted series on `fig` and save it as a PNG.

    For every selected field and every `b` in 2.5, 25, 50, 75 and 97.5, the
    `b`-th and (100 - `b`)-th percentile curves are plotted and the area
    between them is shaded. Observations are overlaid with ``'^'`` markers
    when available.

    :Parameters:
        - `tim`: x values for the curves.
        - `series`: structured array; ``series[n].T`` is iterated to get one
          sample vector per x value.
        - `y`: container of observed data; ``y[n]`` is plotted for every
          selected name `n` for which ``n in y``.
        - `fig`: figure object; a subplot is added to it with
          ``fig.add_subplot(111)`` and it is saved with ``fig.savefig``.
        - `names`: field names to plot. Empty or None selects every field in
          ``series.dtype.names``.
        - `title`: base name of the output file, ``title + '.png'``.
    """
    if not names:
        names = series.dtype.names
    colours = cycle(['b', 'g', 'r', 'c', 'm', 'y', 'k'])
    ax = fig.add_subplot(111)
    for n in names:
        co = next(colours)
        # NOTE(review): 75 and 97.5 repeat the 25 and 2.5 bands with the
        # bounds swapped, so those bands are shaded twice. Kept as in the
        # original because removing them would lighten the plot.
        for b in [2.5, 25, 50, 75, 97.5]:
            st = 'k-' if b == 50 else '%s:' % co
            try:
                lower = [stats.scoreatpercentile(t, b) for t in series[n].T]
                upper = [stats.scoreatpercentile(t, 100 - b) for t in series[n].T]
                ax.plot(tim, lower, st, tim, upper, st, label=n)
                ax.fill_between(tim, lower, upper, facecolor=co, alpha=0.12)
            except Exception:
                # Best effort: a band that cannot be drawn is skipped.
                # Unlike a bare except, this lets KeyboardInterrupt and
                # SystemExit propagate.
                continue
        if n in y:
            try:
                # In the last window we run out of data points.
                # Drawn on `ax` (not the pyplot current axes) so the markers
                # always land on the figure the caller passed in.
                ax.plot(tim, y[n], '^', label=n + ' obs.')
            except Exception:
                pass
    # Save the figure we drew on rather than whichever figure is current.
    fig.savefig(title + '.png')
69
def pred_new_cases(obs, series, weeks, names=None,
                   title='Total new cases per window: predicted vs observed',
                   ws=7):
    """
    Predicted total new cases in a window vs observed.

    Draws on the current pyplot axes. For every selected name that is also
    present in `obs`: the mean of the per-row totals of each element of
    `series` (markers), the NaN-ignoring total of the observations in each
    window of `ws` points (line with markers), and a boxplot of the per-row
    totals.

    :Parameters:
        - `obs`: container of observed data. When it has a ``'time'`` entry,
          ``obs['time'][ws * i]`` for ``i`` in ``1 .. weeks - 1`` are
          converted with ``date2num`` and used as x positions; otherwise the
          positions are ``arange(1, weeks)``.
        - `series`: sequence of structured arrays, one per window.
        - `weeks`: number of windows; ``weeks - 1`` x positions are generated.
        - `names`: field names to plot. Empty or None selects every field in
          ``series[0].dtype.names``.
        - `title`: title set on the current axes.
        - `ws`: window size, in observation points.
    """
    fig = P.gcf()
    P.title(title)
    if not names:
        names = series[0].dtype.names
    ax = P.gca()
    colours = cycle(['b', 'g', 'r', 'c', 'm', 'y', 'k'])
    dated = 'time' in obs
    if dated:  # setting the x positions from the observation dates
        x = date2num([obs['time'][ws * i] for i in range(1, weeks)])
        ax.xaxis_date()
    else:
        x = arange(1, weeks)
    sc = 1 if len(series) == 1 else 5
    # The original min(0.5*max(len(x), 1.0), 0.5) always evaluates to 0.5.
    # Writing the constant avoids depending on whether the star import
    # shadows the builtin min/max.
    W = 0.5 * sc
    # Pair each simulated window with the x position it gets in the boxplot;
    # slicing truncates to the shorter of the two sequences.
    # NOTE(review): the original padded x with a hard-coded ``x[7]``, which
    # raised IndexError with fewer than 8 windows and placed the first mean
    # on an unrelated position. Alignment with the boxes is assumed to be
    # the intent -- confirm.
    xs = x[:len(series)]
    sims = series[:len(x)]
    for n in names:
        if n not in obs:
            continue
        co = next(colours)
        ax.plot(xs, [mean(s[n].sum(axis=1)) for s in sims],
                '%s^' % co, label="Mean pred. %s" % n)
        ax.plot(x,
                [nansum(obs[n][(w + 1) * ws:(w + 2) * ws])
                 for w in range(weeks - 1)],
                '%s-o' % co, label="obs. Prev")
        ax.boxplot([nansum(s[n], axis=1) for s in sims],
                   positions=xs, widths=W, notch=1, vert=1)
    if dated:
        fig.autofmt_xdate()
98
def plot_series2(tim, obs, series, names=None,
                 title='Simulated vs Observed series', wl=7, lag=False):
    """
    Plot median, mean and a 2.5%-97.5% band of simulated series, one subplot
    per name, on the current pyplot figure.

    :Parameters:
        - `tim`: x values; truncated to the number of columns of the
          concatenated series. May hold ``datetime.date`` objects, in which
          case the x axis is formatted as dates.
        - `obs`: container of observed data; ``obs[n][:len(tim)]`` is plotted
          with ``'o'`` markers for every selected name `n` with ``n in obs``.
        - `series`: sequence of structured arrays, concatenated along axis 1
          for each field.
        - `names`: field names to plot. Empty or None selects every field in
          ``series[0].dtype.names``.
        - `title`: title of the first subplot.
        - `wl`: window length; the simulated curves are shifted by
          ``int(lag) * wl`` (days when `tim` holds dates).
        - `lag`: whether to apply that shift.
    """
    ser2 = {}
    for n in series[0].dtype.names:
        ser2[n] = concatenate([s[n] for s in series], axis=1)
    # As in the original, the column count of the last field sets the length.
    tim = tim[:ser2[n].shape[1]]
    fig = P.gcf()
    if not names:
        names = series[0].dtype.names
    colours = cycle(['b', 'g', 'r', 'c', 'm', 'y', 'k'])
    dated = isinstance(tim[0], datetime.date)
    if dated:
        offset = datetime.timedelta(int(lag) * wl)
    else:
        offset = int(lag) * wl
    shifted = array(tim) + offset
    for i, n in enumerate(names):
        ax = fig.add_subplot(len(names), 1, i + 1)
        ax.grid(True)
        if dated:
            ax.xaxis_date()
        co = next(colours)
        if n in obs:
            ax.plot(tim, obs[n][:len(tim)], 'o', label=r"$Observed\; %s$" % n)
        ax.plot(shifted, median(ser2[n], axis=0), 'k-', label=r"$median\; %s$" % n)
        ax.plot(shifted, mean(ser2[n], axis=0), 'k--', label=r"$mean\; %s$" % n)
        lower = [stats.scoreatpercentile(t, 2.5) for t in ser2[n].T]
        upper = [stats.scoreatpercentile(t, 97.5) for t in ser2[n].T]
        if len(series) > 1:  # in the case of iterative simulations
            width = array(upper) - array(lower)
            widest = width.max()
            # A zero-width (or NaN) band has no peaks to interpolate between
            # and would otherwise divide by zero.
            if widest > 0:
                pe, va = peakdet(width / widest * 10, 1)
                # peakdet may find no peak at all, and it reports positions
                # that can be floats: guard the column access and convert to
                # integer indices before using them to index arrays.
                peaks = [int(p) for p in pe[:, 0]] if len(pe) else []
                # interp needs increasing knots; drop duplicates of the ends.
                xp = sorted(set([0] + peaks + [len(lower) - 1]))
                idx = arange(len(lower))
                # Band re-drawn through the points where it is widest.
                lower = interp(idx, xp, array(lower)[xp])
                upper = interp(idx, xp, array(upper)[xp])
        ax.fill_between(shifted, lower, upper, facecolor=co, alpha=0.2)
        if i < (len(names) - 1):
            ax.xaxis.set_ticklabels([])
        ax.legend()
        if i == 0:
            ax.set_title(title)
    P.xlabel('days')
    if dated:
        fig.autofmt_xdate()
145 146
def plot_par_series(tim, ptlist):
    """
    Boxplot every parameter over the windows, one subplot per parameter, in
    a new pyplot figure.

    :Parameters:
        - `tim`: x positions of the boxes, one per element of `ptlist`.
        - `ptlist`: sequence of structured arrays; the fields of
          ``ptlist[0].dtype.names`` are the parameters that get a subplot.
    """
    names = ptlist[0].dtype.names
    # Grid sized so that rows * cols >= number of parameters. The original
    # floor(sqrt) x ceil(sqrt) grid was too small for e.g. 3, 7 or 8
    # parameters, and passed floats where subplot expects integers.
    cols = int(ceil(sqrt(len(names))))
    rows = int(ceil(len(names) / float(cols)))
    P.figure()
    # Figure-level title: the original set an axes title before any subplot
    # existed, i.e. on an axes unrelated to the panels drawn below.
    P.suptitle('Parameters temporal variation')
    for i, n in enumerate(names):
        P.subplot(rows, cols, i + 1)
        P.boxplot([s[n] for s in ptlist], notch=1, positions=tim, vert=1)
        P.ylabel(n)
    # NOTE(review): the scraped source lost its indentation; the x label is
    # assumed to be set once, after the loop -- confirm.
    P.xlabel('Windows')
158
def plot_par_violin(tim, ptlist, priors=None, bp=True):
    """
    Violin-plot every parameter over the windows, one subplot per parameter,
    in a new pyplot figure.

    :Parameters:
        - `tim`: x positions of the violins, one per element of `ptlist`.
          May hold ``datetime.date`` objects.
        - `ptlist`: sequence of structured arrays; the fields of
          ``ptlist[0].dtype.names`` are the parameters that get a subplot.
        - `priors`: optional container with prior samples per parameter
          name. When given and non-empty, ``priors[n]`` is drawn as an extra
          first violin, one step to the left of ``tim[0]``.
        - `bp`: forwarded to `violin_plot`; whether to overlay boxplots.
    """
    fig = P.figure()
    names = ptlist[0].dtype.names
    sq = sqrt(len(names))
    ad = 1 if sq % 1 > 0.5 else 0
    # add_subplot expects integer grid sizes; floor/ceil return floats.
    r = int(floor(sq)) + ad
    c = int(ceil(sq))
    if len(names) == 3:
        r, c = 3, 1
    has_prior = bool(priors)
    if has_prior:
        first = tim[0]
        if len(tim) == 1:
            # No spacing to copy: step back one day, or one unit.
            if isinstance(first, datetime.date):
                step = datetime.timedelta(1)
            else:
                step = 1
        else:
            step = tim[1] - first
        # list(tim) so that "+" prepends for lists and arrays alike; with an
        # ndarray, ``[a] + tim`` would add element-wise instead.
        # NOTE(review): for a single numeric window the original used the
        # positions [-1, 0] regardless of tim[0]; this now matches the date
        # branch and keeps the real position -- confirm.
        tim = [first - step] + list(tim)
    for i, n in enumerate(names):
        ax = fig.add_subplot(r, c, i + 1)
        ax.grid(True)
        data = [s[n] for s in ptlist]
        if has_prior:
            data = [priors[n]] + data
        # Without priors the first violin is ordinary data. The original
        # indexed priors[n] unconditionally and raised KeyError for the
        # default empty priors.
        violin_plot(ax, data, tim, bp, has_prior)
        P.ylabel(n)
    if isinstance(tim[0], datetime.date):
        fig.autofmt_xdate()
184
def violin_plot(ax, data, positions, bp=False, prior=False):
    """
    Create violin plots on an axis

    :Parameters:
        - `ax`: A subplot object
        - `data`: A list of data sets to plot
        - `positions`: x values to position the violins. Can be datetime.date objects.
        - `bp`: Whether to plot the boxplot on top.
        - `prior`: whether the first element of data is a Prior distribution.
          If so, its violin is drawn in green instead of yellow.
    """
    sc = 1
    if isinstance(positions[0], datetime.date):
        ax.xaxis_date()
        count = len(positions)
        positions = date2num(positions)
        sc = 5 if count > 2 else 1
    # The original min(0.5*max(len(positions), 1.0), 0.5) always evaluates
    # to 0.5; writing the constant avoids depending on whether the star
    # import shadows the builtin min/max.
    w = 0.5 * sc

    for i, (d, p) in enumerate(zip(data, positions)):
        color = 'g' if (prior and i == 0) else 'y'
        samples = asarray(d)
        # A kernel density cannot be estimated from fewer than two samples
        # or from constant data; skip the violin (the boxplot still shows).
        if samples.size < 2 or samples.min() == samples.max():
            continue
        k = gaussian_kde(d)  # calculates the kernel density
        m = k.dataset.min()  # lower bound of violin
        M = k.dataset.max()  # upper bound of violin
        # linspace gives exactly 100 points and reaches M; a float-step
        # arange stops short of the upper bound.
        x = linspace(m, M, 100)  # support for violin
        v = k.evaluate(x)  # violin profile (density curve)
        v = v / v.max() * w  # scaling the violin to the available space
        ax.fill_betweenx(x, p, v + p, facecolor=color, alpha=0.3)
        ax.fill_betweenx(x, p, -v + p, facecolor=color, alpha=0.3)
    if bp:
        ax.boxplot(data, notch=1, positions=positions, widths=w, vert=1)
222
def peakdet(v, delta, x=None):
    """
    Detect local maxima and minima ("peaks" and "valleys") in a vector.

    Converted from the MATLAB script at http://billauer.co.il/peakdet.html
    (Eli Billauer, 3.4.05, explicitly not copyrighted, released to the
    public domain).

    A point is considered a maximum peak if it has the maximal value and is
    followed by a value lower by more than `delta`; minima are detected
    symmetrically. The scan starts by looking for a maximum.

    :Parameters:
        - `v`: sequence of values to scan.
        - `delta`: positive scalar threshold.
        - `x`: optional sequence, same length as `v`. When given, its values
          replace the indices of `v` in the first column of the results.

    :Returns:
        ``(maxtab, mintab)``: two arrays with one row per detected extremum.
        Column 0 holds the position, column 1 the value. Both always have
        two columns, so ``maxtab[:, 0]`` is valid even when nothing was
        found (shape ``(0, 2)``).

    :Raises:
        ``ValueError`` when `v` and `x` differ in length, or when `delta` is
        not a positive scalar. (The original called ``sys.exit`` without
        ``sys`` being imported, which surfaced as a NameError.)
    """
    if x is None:
        x = arange(len(v))
    v = asarray(v)

    if len(v) != len(x):
        raise ValueError('Input vectors v and x must have same length')
    if not isscalar(delta):
        raise ValueError('Input argument delta must be a scalar')
    if delta <= 0:
        raise ValueError('Input argument delta must be positive')

    maxtab = []
    mintab = []
    # Plain floats instead of the numpy aliases Inf/NaN, which NumPy 2.0
    # removed.
    mn, mx = float('inf'), float('-inf')
    mnpos, mxpos = float('nan'), float('nan')
    lookformax = True

    for pos, this in zip(x, v):
        # Track the running extremes since the last recorded peak/valley.
        if this > mx:
            mx, mxpos = this, pos
        if this < mn:
            mn, mnpos = this, pos

        if lookformax:
            if this < mx - delta:
                # Dropped far enough below the running max: record it and
                # start looking for the next minimum from here.
                maxtab.append((mxpos, mx))
                mn, mnpos = this, pos
                lookformax = False
        elif this > mn + delta:
            # Rose far enough above the running min: record it and start
            # looking for the next maximum from here.
            mintab.append((mnpos, mn))
            mx, mxpos = this, pos
            lookformax = True

    # Force two columns so an empty result can still be column-indexed.
    return (array(maxtab).reshape(len(maxtab), 2),
            array(mintab).reshape(len(mintab), 2))
if __name__ == "__main__":
    # Placeholder: running the module as a script only prints a greeting.
    # The call form with a single argument works on Python 2 and Python 3.
    print("Hello World")