Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2utility functions for asteval 

3 

4 Matthew Newville <newville@cars.uchicago.edu>, 

5 The University of Chicago 

6""" 

7import io 

8import re 

9import ast 

10import math 

11import numbers 

12from sys import exc_info 

13from tokenize import (tokenize as generate_tokens, 

14 ENCODING as tk_ENCODING, 

15 NAME as tk_NAME) 

16 

# Optional numpy support: when the import fails, ``numpy`` stays None and
# HAS_NUMPY stays False.
HAS_NUMPY = False
numpy = None
try:
    import numpy
except ImportError:
    pass
else:
    ndarr = numpy.ndarray
    HAS_NUMPY = True

# Limits enforced by the safe_* operator wrappers and by _open().
MAX_EXPONENT = 10000          # largest exponent accepted by safe_pow()
MAX_STR_LEN = 256 * 1024      # 256KiB; longest string safe_mult()/safe_add() build
MAX_SHIFT = 1000              # largest shift accepted by safe_lshift()
MAX_OPEN_BUFFER = 256 * 1024  # largest buffering value accepted by _open()

30 

31RESERVED_WORDS = ('and', 'as', 'assert', 'break', 'class', 'continue', 

32 'def', 'del', 'elif', 'else', 'except', 'exec', 

33 'finally', 'for', 'from', 'global', 'if', 'import', 

34 'in', 'is', 'lambda', 'not', 'or', 'pass', 'print', 

35 'raise', 'return', 'try', 'while', 'with', 'True', 

36 'False', 'None', 'eval', 'execfile', '__import__', 

37 '__package__') 

38 

39NAME_MATCH = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*$").match 

40 

# Attribute names considered unsafe to expose to evaluated code. They lead
# back to interpreter internals: class hierarchy, function globals, code
# objects, and stack frames.
# NOTE(review): presumably checked by the interpreter on attribute access;
# lookups made inside str.format() templates may not pass through that
# check -- confirm in the interpreter.
UNSAFE_ATTRS = ('__subclasses__', '__bases__', '__globals__', '__code__',
                '__closure__', '__func__', '__self__', '__module__',
                '__dict__', '__class__', '__call__', '__get__',
                '__getattribute__', '__subclasshook__', '__new__',
                '__init__', 'func_globals', 'func_code', 'func_closure',
                'im_class', 'im_func', 'im_self', 'gi_code', 'gi_frame',
                '__asteval__', 'f_locals', '__mro__',
                # Frames reachable from tracebacks, coroutines and async
                # generators, and the frame attributes that expose globals
                # and builtins. These are the same routes as 'gi_frame' and
                # 'f_locals' above.
                '__traceback__', 'tb_frame', 'f_back', 'f_builtins',
                'f_code', 'f_globals', 'cr_frame', 'ag_frame',
                # Pickling hooks, and the method form of '__mro__'.
                '__reduce__', '__reduce_ex__', 'mro')

48 

# Names copied from Python's builtins into the default symbol table by
# make_symbol_table(); names the running Python lacks are skipped. The tuple
# order is the order in which the entries are inserted.
# NOTE(review): the builtin pow() listed here does not pass through
# safe_pow()'s exponent limit. make_symbol_table() applies FROM_MATH after
# FROM_PY, so the 'pow' entry ends up as math.pow; keep that ordering.
FROM_PY = ('ArithmeticError', 'AssertionError', 'AttributeError',
           'BaseException', 'BufferError', 'BytesWarning',
           'DeprecationWarning', 'EOFError', 'EnvironmentError',
           'Exception', 'False', 'FloatingPointError', 'GeneratorExit',
           'IOError', 'ImportError', 'ImportWarning', 'IndentationError',
           'IndexError', 'KeyError', 'KeyboardInterrupt', 'LookupError',
           'MemoryError', 'NameError', 'None',
           'NotImplementedError', 'OSError', 'OverflowError',
           'ReferenceError', 'RuntimeError', 'RuntimeWarning',
           'StopIteration', 'SyntaxError', 'SyntaxWarning', 'SystemError',
           'SystemExit', 'True', 'TypeError', 'UnboundLocalError',
           'UnicodeDecodeError', 'UnicodeEncodeError', 'UnicodeError',
           'UnicodeTranslateError', 'UnicodeWarning', 'ValueError',
           'Warning', 'ZeroDivisionError', 'abs', 'all', 'any', 'bin',
           'bool', 'bytearray', 'bytes', 'chr', 'complex', 'dict', 'dir',
           'divmod', 'enumerate', 'filter', 'float', 'format', 'frozenset',
           'hash', 'hex', 'id', 'int', 'isinstance', 'len', 'list', 'map',
           'max', 'min', 'oct', 'ord', 'pow', 'range', 'repr',
           'reversed', 'round', 'set', 'slice', 'sorted', 'str', 'sum',
           'tuple', 'zip')

70 

# Names copied from Python's ``math`` module into the default symbol table
# by make_symbol_table(); names the running Python lacks are skipped. When
# numpy is enabled, make_symbol_table() applies FROM_NUMPY afterwards, so
# numpy versions replace same-named entries (e.g. 'sin', 'log', 'pi').
# NOTE(review): 'factorial' is exposed without any size limit, unlike the
# operators routed through safe_pow()/safe_mult(). A huge argument can tie
# up the interpreter -- consider a bounded wrapper.
FROM_MATH = ('acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh',
             'ceil', 'copysign', 'cos', 'cosh', 'degrees', 'e', 'exp',
             'fabs', 'factorial', 'floor', 'fmod', 'frexp', 'fsum',
             'hypot', 'isinf', 'isnan', 'ldexp', 'log', 'log10', 'log1p',
             'modf', 'pi', 'pow', 'radians', 'sin', 'sinh', 'sqrt', 'tan',
             'tanh', 'trunc')

78 

# Names copied from numpy into the default symbol table when numpy is
# available and enabled. make_symbol_table() skips names that the installed
# numpy does not provide, so entries removed in newer releases are harmless.
# 'fft', 'linalg', 'polynomial' and 'random' are numpy submodules.
#
# 'loads' has been dropped from this list. On older NumPy releases
# numpy.loads was an alias of pickle.loads, which can execute arbitrary
# code; current NumPy no longer provides it.
FROM_NUMPY = ('Inf', 'NAN', 'abs', 'add', 'alen', 'all', 'amax', 'amin',
              'angle', 'any', 'append', 'arange', 'arccos', 'arccosh',
              'arcsin', 'arcsinh', 'arctan', 'arctan2', 'arctanh',
              'argmax', 'argmin', 'argsort', 'argwhere', 'around', 'array',
              'array2string', 'asanyarray', 'asarray', 'asarray_chkfinite',
              'ascontiguousarray', 'asfarray', 'asfortranarray',
              'asmatrix', 'asscalar', 'atleast_1d', 'atleast_2d',
              'atleast_3d', 'average', 'bartlett', 'base_repr',
              'bitwise_and', 'bitwise_not', 'bitwise_or', 'bitwise_xor',
              'blackman', 'bool', 'broadcast', 'broadcast_arrays', 'byte',
              'c_', 'cdouble', 'ceil', 'cfloat', 'chararray', 'choose',
              'clip', 'clongdouble', 'clongfloat', 'column_stack',
              'common_type', 'complex', 'complex128', 'complex64',
              'complex_', 'complexfloating', 'compress', 'concatenate',
              'conjugate', 'convolve', 'copy', 'copysign', 'corrcoef',
              'correlate', 'cos', 'cosh', 'cov', 'cross', 'csingle',
              'cumprod', 'cumsum', 'datetime_data', 'deg2rad', 'degrees',
              'delete', 'diag', 'diag_indices', 'diag_indices_from',
              'diagflat', 'diagonal', 'diff', 'digitize', 'divide', 'dot',
              'double', 'dsplit', 'dstack', 'dtype', 'e', 'ediff1d',
              'empty', 'empty_like', 'equal', 'exp', 'exp2', 'expand_dims',
              'expm1', 'extract', 'eye', 'fabs', 'fill_diagonal', 'finfo',
              'fix', 'flatiter', 'flatnonzero', 'fliplr', 'flipud',
              'float', 'float32', 'float64', 'float_', 'floating', 'floor',
              'floor_divide', 'fmax', 'fmin', 'fmod', 'format_parser',
              'frexp', 'frombuffer', 'fromfile', 'fromfunction',
              'fromiter', 'frompyfunc', 'fromregex', 'fromstring', 'fv',
              'genfromtxt', 'getbufsize', 'geterr', 'gradient', 'greater',
              'greater_equal', 'hamming', 'hanning', 'histogram',
              'histogram2d', 'histogramdd', 'hsplit', 'hstack', 'hypot',
              'i0', 'identity', 'iinfo', 'imag', 'in1d', 'index_exp',
              'indices', 'inexact', 'inf', 'info', 'infty', 'inner',
              'insert', 'int', 'int0', 'int16', 'int32', 'int64', 'int8',
              'int_', 'int_asbuffer', 'intc', 'integer', 'interp',
              'intersect1d', 'intp', 'invert', 'ipmt', 'irr', 'iscomplex',
              'iscomplexobj', 'isfinite', 'isfortran', 'isinf', 'isnan',
              'isneginf', 'isposinf', 'isreal', 'isrealobj', 'isscalar',
              'issctype', 'iterable', 'ix_', 'kaiser', 'kron', 'ldexp',
              'left_shift', 'less', 'less_equal', 'linspace',
              # NOTE(review): numpy.load unpickles file contents (which can
              # run code) when called with allow_pickle=True; consider a
              # wrapper that forces allow_pickle=False.
              'little_endian', 'load', 'loadtxt', 'log', 'log10',
              'log1p', 'log2', 'logaddexp', 'logaddexp2', 'logical_and',
              'logical_not', 'logical_or', 'logical_xor', 'logspace',
              'long', 'longcomplex', 'longdouble', 'longfloat', 'longlong',
              'mafromtxt', 'mask_indices', 'mat', 'matrix',
              'maximum', 'maximum_sctype', 'may_share_memory', 'mean',
              # NOTE(review): numpy.memmap accepts write modes ('w+', 'r+'),
              # unlike the read-only _open() exposed as 'open'.
              'median', 'memmap', 'meshgrid', 'mgrid', 'minimum',
              'mintypecode', 'mirr', 'mod', 'modf', 'msort', 'multiply',
              'nan', 'nan_to_num', 'nanargmax', 'nanargmin', 'nanmax',
              'nanmin', 'nansum', 'ndarray', 'ndenumerate', 'ndfromtxt',
              'ndim', 'ndindex', 'negative', 'newaxis', 'nextafter',
              'nonzero', 'not_equal', 'nper', 'npv', 'number',
              'obj2sctype', 'ogrid', 'ones', 'ones_like', 'outer',
              'packbits', 'percentile', 'pi', 'piecewise', 'place', 'pmt',
              'poly', 'poly1d', 'polyadd', 'polyder', 'polydiv', 'polyfit',
              'polyint', 'polymul', 'polysub', 'polyval', 'power', 'ppmt',
              'prod', 'product', 'ptp', 'put', 'putmask', 'pv', 'r_',
              'rad2deg', 'radians', 'rank', 'rate', 'ravel', 'real',
              'real_if_close', 'reciprocal', 'record', 'remainder',
              'repeat', 'reshape', 'resize', 'restoredot', 'right_shift',
              'rint', 'roll', 'rollaxis', 'roots', 'rot90', 'round',
              'round_', 'row_stack', 's_', 'sctype2char', 'searchsorted',
              'select', 'setbufsize', 'setdiff1d', 'seterr', 'setxor1d',
              'shape', 'short', 'sign', 'signbit', 'signedinteger', 'sin',
              'sinc', 'single', 'singlecomplex', 'sinh', 'size',
              'sometrue', 'sort', 'sort_complex', 'spacing', 'split',
              'sqrt', 'square', 'squeeze', 'std', 'str', 'str_',
              'subtract', 'sum', 'swapaxes', 'take', 'tan', 'tanh',
              'tensordot', 'tile', 'trace', 'transpose', 'trapz', 'tri',
              'tril', 'tril_indices', 'tril_indices_from', 'trim_zeros',
              'triu', 'triu_indices', 'triu_indices_from', 'true_divide',
              'trunc', 'ubyte', 'uint', 'uint0', 'uint16', 'uint32',
              'uint64', 'uint8', 'uintc', 'uintp', 'ulonglong', 'union1d',
              'unique', 'unravel_index', 'unsignedinteger', 'unwrap',
              'ushort', 'vander', 'var', 'vdot', 'vectorize', 'vsplit',
              'vstack', 'where', 'who', 'zeros', 'zeros_like',
              'fft', 'linalg', 'polynomial', 'random')

155 

156 

# Extra symbol names bound to numpy functions, mapped as
# alias -> numpy attribute name (the math-style spellings of numpy's names).
NUMPY_RENAMES = dict(
    ln='log',
    asin='arcsin',
    acos='arccos',
    atan='arctan',
    atan2='arctan2',
    atanh='arctanh',
    acosh='arccosh',
    asinh='arcsinh',
)

160 

161 

162def _open(filename, mode='r', buffering=0): 

163 """read only version of open()""" 

164 if mode not in ('r', 'rb', 'rU'): 

165 raise RuntimeError("Invalid open file mode, must be 'r', 'rb', or 'rU'") 

166 if buffering > MAX_OPEN_BUFFER: 

167 raise RuntimeError("Invalid buffering value, max buffer size is {}".format(MAX_OPEN_BUFFER)) 

168 return open(filename, mode, buffering) 

169 

170def _type(obj, *varargs, **varkws): 

171 """type that prevents varargs and varkws""" 

172 return type(obj).__name__ 

173 

174 

# Sandboxed replacements that take the place of the builtins of the same
# name in the default symbol table.
LOCALFUNCS = dict(open=_open, type=_type)

176 

177 

178# Safe versions of functions to prevent denial of service issues 

179 

def safe_pow(base, exp, max_exponent=None):
    """Safe version of pow: ``base ** exp`` with a bounded exponent.

    Arguments
    ---------
    base : number or array
        value to raise to a power.
    exp : number or array
        exponent; rejected when it exceeds the limit.
    max_exponent : number, optional
        largest allowed exponent; defaults to MAX_EXPONENT.

    Returns
    -------
    ``base ** exp``

    Raises
    ------
    RuntimeError
        if a numeric exponent exceeds the limit, or if the largest
        non-NaN element of a numpy exponent array does. For complex
        exponents the real part is compared.
    """
    if max_exponent is None:
        max_exponent = MAX_EXPONENT
    if isinstance(exp, numbers.Number):
        # Complex numbers cannot be ordered. Compare their real part
        # instead of failing with a TypeError on the comparison.
        bound = exp.real if isinstance(exp, numbers.Complex) else exp
        if bound > max_exponent:
            raise RuntimeError("Invalid exponent, max exponent is {}".format(max_exponent))
    elif HAS_NUMPY and isinstance(exp, numpy.ndarray):
        # nanmax() raises ValueError on a zero-size array, but an empty
        # exponent array cannot be expensive.
        if exp.size > 0 and numpy.nanmax(exp) > max_exponent:
            raise RuntimeError("Invalid exponent, max exponent is {}".format(max_exponent))
    return base ** exp

190 

191 

def safe_mult(a, b, max_len=None):
    """Safe version of multiply: ``a * b`` with bounded sequence repetition.

    Arguments
    ---------
    a, b : objects
        operands. When one operand is a str/bytes/bytearray and the other
        is an integer, the length of the repeated result is checked, in
        either operand order.
    max_len : int, optional
        longest allowed result; defaults to MAX_STR_LEN.

    Returns
    -------
    ``a * b``

    Raises
    ------
    RuntimeError
        if the repeated sequence would be longer than the limit.
    """
    if max_len is None:
        max_len = MAX_STR_LEN
    text_types = (str, bytes, bytearray)
    # Repetition is commutative ('ab' * 3 == 3 * 'ab'), so move the
    # sequence to the left before checking.
    seq, count = (b, a) if isinstance(b, text_types) else (a, b)
    if (isinstance(seq, text_types) and isinstance(count, numbers.Integral)
            and len(seq) * count > max_len):
        raise RuntimeError("String length exceeded, max string length is {}".format(max_len))
    return a * b

197 

198 

def safe_add(a, b, max_len=None):
    """Safe version of add: ``a + b`` with bounded text concatenation.

    Arguments
    ---------
    a, b : objects
        operands. When both are str, or both are bytes/bytearray, the
        combined length is checked.
    max_len : int, optional
        longest allowed result; defaults to MAX_STR_LEN.

    Returns
    -------
    ``a + b``

    Raises
    ------
    RuntimeError
        if the concatenation would be longer than the limit.
    """
    if max_len is None:
        max_len = MAX_STR_LEN
    binary_types = (bytes, bytearray)
    both_text = isinstance(a, str) and isinstance(b, str)
    both_binary = isinstance(a, binary_types) and isinstance(b, binary_types)
    if (both_text or both_binary) and len(a) + len(b) > max_len:
        raise RuntimeError("String length exceeded, max string length is {}".format(max_len))
    return a + b

204 

205 

def safe_lshift(a, b, max_shift=None):
    """Safe version of lshift: ``a << b`` with a bounded shift count.

    Arguments
    ---------
    a : number or array
        value to shift.
    b : number or array
        shift count; rejected when it exceeds the limit.
    max_shift : number, optional
        largest allowed shift; defaults to MAX_SHIFT.

    Returns
    -------
    ``a << b``

    Raises
    ------
    RuntimeError
        if a numeric shift exceeds the limit, or if the largest non-NaN
        element of a numpy shift array does.
    """
    if max_shift is None:
        max_shift = MAX_SHIFT
    if isinstance(b, numbers.Number):
        if b > max_shift:
            raise RuntimeError("Invalid left shift, max left shift is {}".format(max_shift))
    elif HAS_NUMPY and isinstance(b, numpy.ndarray):
        # nanmax() raises ValueError on a zero-size array, but an empty
        # shift array cannot be expensive.
        if b.size > 0 and numpy.nanmax(b) > max_shift:
            raise RuntimeError("Invalid left shift, max left shift is {}".format(max_shift))
    return a << b

216 

217 

# Map each AST operator node class to the function that applies it.
# Operations that evaluated code could use to exhaust time or memory go
# through the safe_* wrappers.
OPERATORS = {ast.Is: lambda a, b: a is b,
             ast.IsNot: lambda a, b: a is not b,
             ast.In: lambda a, b: a in b,
             ast.NotIn: lambda a, b: a not in b,
             ast.Add: safe_add,
             ast.BitAnd: lambda a, b: a & b,
             ast.BitOr: lambda a, b: a | b,
             ast.BitXor: lambda a, b: a ^ b,
             ast.Div: lambda a, b: a / b,
             ast.FloorDiv: lambda a, b: a // b,
             ast.LShift: safe_lshift,
             ast.RShift: lambda a, b: a >> b,
             ast.Mult: safe_mult,
             # Matrix multiplication (``a @ b``, PEP 465). Without this
             # entry, op2func() raised KeyError for it.
             ast.MatMult: lambda a, b: a @ b,
             ast.Pow: safe_pow,
             ast.Sub: lambda a, b: a - b,
             ast.Mod: lambda a, b: a % b,
             # NOTE(review): both operands have already been evaluated by
             # the time these run; short-circuiting presumably happens in
             # the interpreter's BoolOp handler -- confirm.
             ast.And: lambda a, b: a and b,
             ast.Or: lambda a, b: a or b,
             ast.Eq: lambda a, b: a == b,
             ast.Gt: lambda a, b: a > b,
             ast.GtE: lambda a, b: a >= b,
             ast.Lt: lambda a, b: a < b,
             ast.LtE: lambda a, b: a <= b,
             ast.NotEq: lambda a, b: a != b,
             # unary operators take a single operand
             ast.Invert: lambda a: ~a,
             ast.Not: lambda a: not a,
             ast.UAdd: lambda a: +a,
             ast.USub: lambda a: -a}

246 

247 

def valid_symbol_name(name):
    """Determine whether the input symbol name is a valid name.

    Arguments
    ---------
    name : str
        name to check for validity.

    Returns
    --------
    valid : bool
        whether name is a valid symbol name

    A valid name is a string that is a Python identifier (per
    ``str.isidentifier``, so non-ASCII letters are allowed), is not a
    Python keyword, and is not in RESERVED_WORDS. Non-string input is
    reported as invalid instead of raising.
    """
    import keyword

    # Cheap structural checks come first. str.isidentifier() never raises
    # for str input, and it rejects '', embedded spaces, leading digits
    # and punctuation.
    if not isinstance(name, str) or not name.isidentifier():
        return False
    # Python 3 keywords such as 'nonlocal' and 'yield' look like identifiers
    # but cannot be bound.
    if keyword.iskeyword(name):
        return False
    return name not in RESERVED_WORDS

272 

273 

def op2func(op):
    """Look up the function implementing the AST operator node ``op``.

    The lookup is keyed on the node's class in OPERATORS, so an operator
    type without an entry raises KeyError.
    """
    node_class = op.__class__
    func = OPERATORS[node_class]
    return func

277 

278 

class Empty:
    """Sentinel class whose instances are always falsy."""

    def __init__(self):
        """Create the sentinel; it carries no state."""
        pass

    def __bool__(self):
        """Return False, so the sentinel is falsy under Python 3."""
        return False

    # Python 2 spelling of __bool__. Python 3 ignores __nonzero__, so it is
    # kept only as an alias for callers that invoke it directly.
    __nonzero__ = __bool__


# Shared sentinel instance.
# NOTE(review): presumably marks an explicit ``return None`` inside
# evaluated code -- confirm against the interpreter.
ReturnedNone = Empty()

292 

293 

class ExceptionHolder(object):
    """Record an error raised while evaluating an expression.

    Besides the details passed in, ``sys.exc_info()`` is captured at
    construction time. An exception that is being handled at that moment
    therefore supplies ``exc`` and ``msg`` when they are not given.
    """

    def __init__(self, node, exc=None, msg='', expr=None, lineno=None):
        """Store the error details, falling back to the active exception.

        Arguments
        ---------
        node : object or None
            AST node being evaluated; its ``col_offset``, if any, positions
            the caret in get_error().
        exc : exception class, optional
            defaults to the type of the exception currently being handled.
        msg : str, optional
            defaults to the exception instance currently being handled.
        expr : str, optional
            source text shown by get_error().
        lineno : int, optional
            line number of the failing expression.
        """
        self.node, self.expr, self.msg = node, expr, msg
        self.exc, self.lineno = exc, lineno
        self.exc_info = exc_info()
        active_type, active_value = self.exc_info[:2]
        if self.exc is None and active_type is not None:
            self.exc = active_type
        if self.msg == '' and active_value is not None:
            self.msg = active_value

    def get_error(self):
        """Return ``(exception_name, message)`` describing the error.

        The message has three parts: the expression prefixed by one space,
        a ``^^^`` caret line when the node's column offset is positive, and
        the stored message. An exception name that is missing or 'None'
        becomes 'UnknownError'.
        """
        node = self.node
        col_offset = -1 if node is None else getattr(node, 'col_offset', -1)

        try:
            exc_name = self.exc.__name__
        except AttributeError:
            exc_name = str(self.exc)
        if exc_name in (None, 'None'):
            exc_name = 'UnknownError'

        lines = [" %s" % self.expr]
        if col_offset > 0:
            lines.append(" %s^^^" % (' ' * col_offset))
        lines.append(str(self.msg))
        return exc_name, '\n'.join(lines)

330 

331 

class NameFinder(ast.NodeVisitor):
    """Collect the names read (``Load`` context) in an AST, first-seen order."""

    def __init__(self):
        """Start with an empty list of collected names."""
        self.names = []
        ast.NodeVisitor.__init__(self)

    def generic_visit(self, node):
        """Record ``node`` if it is a newly seen loaded Name, then visit its children."""
        is_loaded_name = (node.__class__.__name__ == 'Name'
                          and node.ctx.__class__ == ast.Load)
        if is_loaded_name and node.id not in self.names:
            self.names.append(node.id)
        ast.NodeVisitor.generic_visit(self, node)

346 

# Dict of builtin names -> objects. ``__builtins__`` may be either the
# builtins module itself or its dict (a CPython detail), so normalise it
# to the dict.
builtins = (__builtins__ if isinstance(__builtins__, dict)
            else __builtins__.__dict__)

350 

def get_ast_names(astnode):
    """Return the names read (Load context) beneath ``astnode``.

    The names are deduplicated and listed in the order first seen, as
    collected by NameFinder.
    """
    collector = NameFinder()
    collector.generic_visit(astnode)
    found = collector.names
    return found

356 

357 

def make_symbol_table(use_numpy=True, **kws):
    """Create a default symbol table, taking a dict of user-defined symbols.

    Arguments
    ---------
    use_numpy : bool, optional
        whether to include symbols from numpy; this has no effect when
        numpy could not be imported [True]
    kws : optional
        additional symbol name, value pairs to include in symbol table

    Returns
    --------
    symbol_table : dict
        a symbol table that can be used in `asteval.Interpreter`

    Notes
    -----
    Sources are applied in order, and later ones override earlier ones:
    the FROM_PY builtins, then FROM_MATH, then FROM_NUMPY plus the
    NUMPY_RENAMES aliases, then LOCALFUNCS (sandboxed ``open``/``type``),
    then ``kws``. Names missing from the running Python or numpy are
    skipped.
    """
    missing = object()

    symtable = {sym: builtins[sym] for sym in FROM_PY if sym in builtins}

    # A single getattr() with a sentinel replaces hasattr() + getattr().
    # This avoids triggering numpy's deprecated-alias warnings twice.
    for sym in FROM_MATH:
        value = getattr(math, sym, missing)
        if value is not missing:
            symtable[sym] = value

    if HAS_NUMPY and use_numpy:
        for sym in FROM_NUMPY:
            value = getattr(numpy, sym, missing)
            if value is not missing:
                symtable[sym] = value
        for name, sym in NUMPY_RENAMES.items():
            value = getattr(numpy, sym, missing)
            if value is not missing:
                symtable[name] = value

    symtable.update(LOCALFUNCS)
    symtable.update(kws)

    return symtable
393 symtable.update(LOCALFUNCS) 

394 symtable.update(kws) 

395 

396 return symtable