Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of Patsy 

2# Copyright (C) 2011 Nathaniel Smith <njs@pobox.com> 

3# See file LICENSE.txt for license information. 

4 

5# Utilities that require an over-intimate knowledge of Python's execution 

6# environment. 

7 

8# NB: if you add any __future__ imports to this file then you'll have to 

9# adjust the tests that deal with checking the caller's execution environment 

10# for __future__ flags! 

11 

12# These are made available in the patsy.* namespace 

13__all__ = ["EvalEnvironment", "EvalFactor"] 

14 

15import sys 

16import __future__ 

17import inspect 

18import tokenize 

19import ast 

20import numbers 

21import six 

22from patsy import PatsyError 

23from patsy.util import PushbackAdapter, no_pickling, assert_no_pickling 

24from patsy.tokens import (pretty_untokenize, normalize_token_spacing, 

25 python_tokenize) 

26from patsy.compat import call_and_wrap_exc 

27 

def _all_future_flags():
    """Return the OR of the compiler flags of all still-optional features.

    A ``__future__`` feature is counted when its mandatory release lies
    after the running interpreter's version, or when it has no mandatory
    release at all.

    :returns: An integer bitmask of compiler flags.
    """
    flags = 0
    for feature_name in __future__.all_feature_names:
        feature = getattr(__future__, feature_name)
        mandatory = feature.getMandatoryRelease()
        # getMandatoryRelease() returns None for a feature that was dropped
        # or postponed without a decision. Such a feature is still optional,
        # and comparing None against sys.version_info would raise TypeError
        # on Python 3, so it has to be checked for explicitly.
        if mandatory is None or mandatory > sys.version_info:
            flags |= feature.compiler_flag
    return flags

# Computed once at import time; used to mask code-object and user-supplied
# flags down to the __future__-related bits.
_ALL_FUTURE_FLAGS = _all_future_flags()

37 

class VarLookupDict(object):
    """Minimal dict-like object that looks keys up in a 'stack' of dicts.

    Lookups try the first dict, then the second, and so on, returning the
    first hit. Assignments never touch the supplied dicts: they go into an
    internal, zeroth dict that is consulted before all the others.
    """

    def __init__(self, dicts):
        # Slot 0 is the private dict that receives every assignment.
        self._dicts = [{}]
        self._dicts.extend(dicts)

    def __getitem__(self, key):
        for namespace in self._dicts:
            try:
                found = namespace[key]
            except KeyError:
                continue
            return found
        raise KeyError(key)

    def __setitem__(self, key, value):
        scratch = self._dicts[0]
        scratch[key] = value

    def __contains__(self, key):
        try:
            self[key]
        except KeyError:
            return False
        return True

    def get(self, key, default=None):
        """Return ``self[key]``, or `default` when no dict holds `key`."""
        try:
            value = self[key]
        except KeyError:
            value = default
        return value

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self._dicts)

    __getstate__ = no_pickling

74 

75 

def test_VarLookupDict():
    """Check lookup order, assignment isolation, get(), repr and pickling."""
    first = {"a": 1}
    second = {"a": 2, "b": 3}
    stack = VarLookupDict([first, second])
    # Earlier dicts shadow later ones.
    assert stack["a"] == 1
    assert stack["b"] == 3
    assert "a" in stack
    assert "c" not in stack
    from nose.tools import assert_raises
    assert_raises(KeyError, stack.__getitem__, "c")
    # Assignment lands in the internal dict and leaves the inputs untouched.
    stack["a"] = 10
    assert stack["a"] == 10
    assert first["a"] == 1
    assert stack.get("c") is None
    assert isinstance(repr(stack), six.string_types)

    assert_no_pickling(stack)

93 

def ast_names(code):
    """Iterator that yields all the (ast) names in a Python expression.

    :arg code: A string containing a Python expression.
    :raises PatsyError: While iterating, if the expression contains a
      lambda, a comprehension or a generator expression.
    """
    # Constructs that can bind new names are hard to analyse correctly, so
    # they are rejected outright.
    forbidden = [ast.Lambda, ast.ListComp, ast.GeneratorExp]
    if sys.version_info >= (2, 7):
        forbidden.extend([ast.DictComp, ast.SetComp])
    forbidden = tuple(forbidden)

    tree = ast.parse(code)
    for node in ast.walk(tree):
        if isinstance(node, forbidden):
            raise PatsyError("Lambda, list/dict/set comprehension, generator "
                             "expression in patsy formula not currently supported.")
        if isinstance(node, ast.Name):
            yield node.id

111 

def test_ast_names():
    """Check ast_names against a handful of simple expressions."""
    cases = (
        ('np.log(x)', ['np', 'x']),
        ('x', ['x']),
        ('center(x + 1)', ['center', 'x']),
        ('dt.date.dt.month', ['dt']),
    )
    for expression, wanted in cases:
        found = set(ast_names(expression))
        assert found == set(wanted)

119 

def test_ast_names_disallowed_nodes():
    """Every name-binding construct must make ast_names raise PatsyError."""
    from nose.tools import assert_raises

    def drain(code):
        # ast_names is a generator, so it only raises once it is iterated.
        return list(ast_names(code))

    rejected = ["lambda x: x + y",
                "[x + 1 for x in range(10)]",
                "(x + 1 for x in range(10))"]
    if sys.version_info >= (2, 7):
        rejected += ["{x: True for x in range(10)}",
                     "{x + 1 for x in range(10)}"]
    for snippet in rejected:
        assert_raises(PatsyError, drain, snippet)

130 

class EvalEnvironment(object):
    """Represents a Python execution environment.

    Encapsulates a namespace for variable lookup and set of __future__
    flags."""
    def __init__(self, namespaces, flags=0):
        """Store a copy of the `namespaces` list and the `flags` bitmask.

        :arg namespaces: An iterable of dict-like objects, searched in order.
        :arg flags: Compiler flags; only ``__future__`` bits are accepted.
        """
        assert not flags & ~_ALL_FUTURE_FLAGS
        self._namespaces = list(namespaces)
        self.flags = flags

    @property
    def namespace(self):
        """A dict-like object that can be used to look up variables accessible
        from the encapsulated environment."""
        return VarLookupDict(self._namespaces)

    def with_outer_namespace(self, outer_namespace):
        """Return a new EvalEnvironment with an extra namespace added.

        This namespace will be used only for variables that are not found in
        any existing namespace, i.e., it is "outside" them all."""
        return self.__class__(self._namespaces + [outer_namespace],
                              self.flags)

    def eval(self, expr, source_name="<string>", inner_namespace=None):
        """Evaluate some Python code in the encapsulated environment.

        :arg expr: A string containing a Python expression.
        :arg source_name: A name for this string, for use in tracebacks.
        :arg inner_namespace: A dict-like object that will be checked first
          when `expr` attempts to access any variables. Defaults to an empty
          namespace.
        :returns: The value of `expr`.
        """
        # Use None rather than a shared ``{}`` default so that no mutable
        # object is shared between calls.
        if inner_namespace is None:
            inner_namespace = {}
        code = compile(expr, source_name, "eval", self.flags, False)
        return eval(code, {}, VarLookupDict([inner_namespace]
                                            + self._namespaces))

    @classmethod
    def capture(cls, eval_env=0, reference=0):
        """Capture an execution environment from the stack.

        If `eval_env` is already an :class:`EvalEnvironment`, it is returned
        unchanged. Otherwise, we walk up the stack by ``eval_env + reference``
        steps and capture that function's evaluation environment.

        For ``eval_env=0`` and ``reference=0``, the default, this captures the
        stack frame of the function that calls :meth:`capture`. If ``eval_env
        + reference`` is 1, then we capture that function's caller, etc.

        This somewhat complicated calling convention is designed to be
        convenient for functions which want to capture their caller's
        environment by default, but also allow explicit environments to be
        specified. See the second example.

        Example::

          x = 1
          this_env = EvalEnvironment.capture()
          assert this_env.namespace["x"] == 1
          def child_func():
              return EvalEnvironment.capture(1)
          this_env_from_child = child_func()
          assert this_env_from_child.namespace["x"] == 1

        Example::

          # This function can be used like:
          #   my_model(formula_like, data)
          #     -> evaluates formula_like in caller's environment
          #   my_model(formula_like, data, eval_env=1)
          #     -> evaluates formula_like in caller's caller's environment
          #   my_model(formula_like, data, eval_env=my_env)
          #     -> evaluates formula_like in environment 'my_env'
          def my_model(formula_like, data, eval_env=0):
              eval_env = EvalEnvironment.capture(eval_env, reference=1)
              return model_setup_helper(formula_like, data, eval_env)

        This is how :func:`dmatrix` works.

        :raises TypeError: If `eval_env` is neither an integer nor an
          :class:`EvalEnvironment`.
        :raises ValueError: If the call stack is shallower than the
          requested depth.

        .. versionadded:: 0.2.0
           The ``reference`` argument.
        """
        if isinstance(eval_env, cls):
            return eval_env
        elif isinstance(eval_env, numbers.Integral):
            depth = eval_env + reference
        else:
            raise TypeError("Parameter 'eval_env' must be either an integer "
                            "or an instance of patsy.EvalEnvironment.")
        frame = inspect.currentframe()
        try:
            for i in range(depth + 1):
                if frame is None:
                    raise ValueError("call-stack is not that deep!")
                frame = frame.f_back
            # The final step may itself have walked off the top of the
            # stack; without this check the attribute access below would
            # fail with AttributeError instead of the documented ValueError.
            if frame is None:
                raise ValueError("call-stack is not that deep!")
            return cls([frame.f_locals, frame.f_globals],
                       frame.f_code.co_flags & _ALL_FUTURE_FLAGS)
        # The try/finally is important to avoid a potential reference cycle --
        # any exception traceback will carry a reference to *our* frame, which
        # contains a reference to our local variables, which would otherwise
        # carry a reference to some parent frame, where the exception was
        # caught...:
        finally:
            del frame

    def subset(self, names):
        """Creates a new, flat EvalEnvironment that contains only
        the variables specified.

        :arg names: An iterable of variable names; each must be resolvable
          in this environment, otherwise the lookup raises KeyError.
        """
        vld = VarLookupDict(self._namespaces)
        new_ns = dict((name, vld[name]) for name in names)
        return EvalEnvironment([new_ns], self.flags)

    def _namespace_ids(self):
        # Identity (not content) of each namespace, in lookup order.
        return [id(n) for n in self._namespaces]

    def __eq__(self, other):
        # Equal only when the flags match and both environments refer to
        # the very same namespace objects in the same order.
        return (isinstance(other, EvalEnvironment)
                and self.flags == other.flags
                and self._namespace_ids() == other._namespace_ids())

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        # Built from the same fields that __eq__ compares.
        return hash((EvalEnvironment,
                     self.flags,
                     tuple(self._namespace_ids())))

    __getstate__ = no_pickling

260 

def _a(): # pragma: no cover
    # Outermost link of the _a -> _b -> _c call chain used by
    # test_EvalEnvironment_capture_namespace. The local below deliberately
    # reuses the function's own name so the test can recognise this frame's
    # locals by looking up "_a" and finding 1.
    _a = 1
    return _b()

264 

def _b(): # pragma: no cover
    # Middle link of the _a -> _b -> _c call chain. The local "_b" marks
    # this frame's locals for test_EvalEnvironment_capture_namespace.
    _b = 1
    return _c()

268 

def _c(): # pragma: no cover
    # Innermost link of the _a -> _b -> _c call chain. Returns six captured
    # environments, in pairs that use different spellings of the same total
    # depth: depth 0 (this frame), depth 1 (_b's frame), depth 2 (_a's frame).
    # The capture() calls must stay directly in this function body, because
    # the depth is counted from the frame that calls capture().
    _c = 1
    return [EvalEnvironment.capture(),
            EvalEnvironment.capture(0),
            EvalEnvironment.capture(1),
            EvalEnvironment.capture(0, reference=1),
            EvalEnvironment.capture(2),
            EvalEnvironment.capture(0, 2),
            ]

278 

def test_EvalEnvironment_capture_namespace():
    """Check which frame each capture() depth picks up, plus error cases."""
    envs = _a()
    c0, c, b1, b2, a1, a2 = envs
    # Every captured frame can see this module-level test function.
    for env in envs:
        assert "test_EvalEnvironment_capture_namespace" in env.namespace
    # Each pair of environments sees the marker local of its own frame.
    for env in (c0, c):
        assert env.namespace["_c"] == 1
    for env in (b1, b2):
        assert env.namespace["_b"] == 1
    for env in (a1, a2):
        assert env.namespace["_a"] == 1
    # From _b's frame, "_c" is not a local, so it resolves to the function.
    for env in (b1, b2):
        assert env.namespace["_c"] is _c
    from nose.tools import assert_raises
    assert_raises(ValueError, EvalEnvironment.capture, 10 ** 6)

    assert EvalEnvironment.capture(b1) is b1

    assert_raises(TypeError, EvalEnvironment.capture, 1.2)

    assert_no_pickling(EvalEnvironment.capture())

303 

def test_EvalEnvironment_capture_flags():
    """Check that capture() records the __future__ flags of the captured frame.

    The same source is compiled and executed twice: once as-is and once
    with a ``from __future__ import ...`` line prepended. Environments
    captured from inside ``f`` must carry the flag only in the second run;
    the environment captured from inside the ``call_capture_0`` lambda
    (defined in this module) must never carry it.
    """
    if sys.version_info >= (3,):
        # This is the only __future__ feature currently usable in Python
        # 3... fortunately it is probably not going anywhere.
        TEST_FEATURE = "barry_as_FLUFL"
    else:
        TEST_FEATURE = "division"
    test_flag = getattr(__future__, TEST_FEATURE).compiler_flag
    assert test_flag & _ALL_FUTURE_FLAGS
    source = ("def f():\n"
              " in_f = 'hi from f'\n"
              " global RETURN_INNER, RETURN_OUTER, RETURN_INNER_FROM_OUTER\n"
              " RETURN_INNER = EvalEnvironment.capture(0)\n"
              " RETURN_OUTER = call_capture_0()\n"
              " RETURN_INNER_FROM_OUTER = call_capture_1()\n"
              "f()\n")
    # flags=0, dont_inherit=1: the compiled code must not inherit any
    # __future__ flags from this module.
    code = compile(source, "<test string>", "exec", 0, 1)
    # call_capture_0 captures the lambda's own frame; call_capture_1
    # captures the frame that called the lambda, i.e. f's frame.
    env = {"EvalEnvironment": EvalEnvironment,
           "call_capture_0": lambda: EvalEnvironment.capture(0),
           "call_capture_1": lambda: EvalEnvironment.capture(1),
           }
    # Copy before the first exec so the second run starts from a clean
    # globals dict.
    env2 = dict(env)
    six.exec_(code, env)
    assert env["RETURN_INNER"].namespace["in_f"] == "hi from f"
    assert env["RETURN_INNER_FROM_OUTER"].namespace["in_f"] == "hi from f"
    assert "in_f" not in env["RETURN_OUTER"].namespace
    assert env["RETURN_INNER"].flags & _ALL_FUTURE_FLAGS == 0
    assert env["RETURN_OUTER"].flags & _ALL_FUTURE_FLAGS == 0
    assert env["RETURN_INNER_FROM_OUTER"].flags & _ALL_FUTURE_FLAGS == 0

    # Second run: identical source, but with the __future__ import in front.
    code2 = compile(("from __future__ import %s\n" % (TEST_FEATURE,))
                    + source,
                    "<test string 2>", "exec", 0, 1)
    six.exec_(code2, env2)
    assert env2["RETURN_INNER"].namespace["in_f"] == "hi from f"
    assert env2["RETURN_INNER_FROM_OUTER"].namespace["in_f"] == "hi from f"
    assert "in_f" not in env2["RETURN_OUTER"].namespace
    assert env2["RETURN_INNER"].flags & _ALL_FUTURE_FLAGS == test_flag
    assert env2["RETURN_OUTER"].flags & _ALL_FUTURE_FLAGS == 0
    assert env2["RETURN_INNER_FROM_OUTER"].flags & _ALL_FUTURE_FLAGS == test_flag

344 

def test_EvalEnvironment_eval_namespace():
    """Check lookup precedence between inner, own, captured and outer namespaces."""
    base_env = EvalEnvironment([{"a": 1}])
    assert base_env.eval("2 * a") == 2
    # An inner namespace wins over the environment's own namespaces.
    assert base_env.eval("2 * a", inner_namespace={"a": 2}) == 4
    from nose.tools import assert_raises
    assert_raises(NameError, base_env.eval, "2 * b")
    # 'a' has to be a local of this function, assigned before the capture:
    # capture(0) reads this very frame.
    a = 3
    captured_env = EvalEnvironment.capture(0)
    assert captured_env.eval("2 * a") == 6

    # An outer namespace is consulted only for names found nowhere else.
    widened_env = base_env.with_outer_namespace({"a": 10, "b": 3})
    assert widened_env.eval("2 * a") == 2
    assert widened_env.eval("2 * b") == 6

358 

def test_EvalEnvironment_eval_flags():
    """Check that eval() honours the environment's __future__ flags.

    Also checks that subset() and with_outer_namespace() carry the flags
    over to the environments they create.
    """
    from nose.tools import assert_raises
    if sys.version_info >= (3,):
        # This joke __future__ statement replaces "!=" with "<>":
        #   http://www.python.org/dev/peps/pep-0401/
        test_flag = __future__.barry_as_FLUFL.compiler_flag
        assert test_flag & _ALL_FUTURE_FLAGS

        env = EvalEnvironment([{"a": 11}], flags=0)
        assert env.eval("a != 0") == True
        assert_raises(SyntaxError, env.eval, "a <> 0")
        assert env.subset(["a"]).flags == 0
        assert env.with_outer_namespace({"b": 10}).flags == 0

        env2 = EvalEnvironment([{"a": 11}], flags=test_flag)
        assert env2.eval("a <> 0") == True
        assert_raises(SyntaxError, env2.eval, "a != 0")
        assert env2.subset(["a"]).flags == test_flag
        assert env2.with_outer_namespace({"b": 10}).flags == test_flag
    else:
        test_flag = __future__.division.compiler_flag
        assert test_flag & _ALL_FUTURE_FLAGS

        env = EvalEnvironment([{"a": 11}], flags=0)
        assert env.eval("a / 2") == 11 // 2 == 5
        assert env.subset(["a"]).flags == 0
        assert env.with_outer_namespace({"b": 10}).flags == 0

        env2 = EvalEnvironment([{"a": 11}], flags=test_flag)
        assert env2.eval("a / 2") == 11 * 1. / 2 != 5
        # This comparison used to be a bare expression statement, so it
        # never actually checked anything.
        assert env2.subset(["a"]).flags == test_flag
        assert env2.with_outer_namespace({"b": 10}).flags == test_flag

391 

def test_EvalEnvironment_subset():
    """A subset environment exposes only the requested names."""
    full_env = EvalEnvironment([{"a": 1}, {"b": 2}, {"c": 3}])

    only_a = full_env.subset(["a"])
    assert only_a.eval("a") == 1
    from nose.tools import assert_raises
    for missing in ("b", "c"):
        assert_raises(NameError, only_a.eval, missing)

    b_and_c = full_env.subset(["b", "c"])
    assert b_and_c.eval("b * c") == 6
    assert_raises(NameError, b_and_c.eval, "a")

404 

def test_EvalEnvironment_eq():
    """Check identity-based equality and hashing of EvalEnvironment."""
    # Two environments are eq only if they refer to exactly the same
    # global/local dicts
    # Both captures below are taken from this function's own frame.
    env1 = EvalEnvironment.capture(0)
    env2 = EvalEnvironment.capture(0)
    assert env1 == env2
    assert hash(env1) == hash(env2)
    # Each call of the lambda runs in a frame of its own, so the two
    # captured environments are expected to hold different locals dicts.
    capture_local_env = lambda: EvalEnvironment.capture(0)
    env3 = capture_local_env()
    env4 = capture_local_env()
    assert env3 != env4

416 

# Holds the names star-imported from patsy.builtins. The exec call uses this
# dict as its locals, so the import fills it in.
# EvalFactor.memorize_passes_needed adds it as the outermost namespace of
# the evaluation environment.
_builtins_dict = {}
six.exec_("from patsy.builtins import *", {}, _builtins_dict)
# This is purely to make the existence of patsy.builtins visible to systems
# like py2app and py2exe. It's basically free, since the above line guarantees
# that patsy.builtins will be present in sys.modules in any case.
import patsy.builtins

423 

class EvalFactor(object):
    def __init__(self, code, origin=None):
        """A factor class that executes arbitrary Python code and supports
        stateful transforms.

        :arg code: A string containing a Python expression, that will be
          evaluated to produce this factor's value.
        :arg origin: Stored unchanged as ``self.origin``; it takes no part in
          equality or hashing.

        This is the standard factor class that is used when parsing formula
        strings and implements the standard stateful transform processing. See
        :ref:`stateful-transforms` and :ref:`expert-model-specification`.

        Two EvalFactor's are considered equal (e.g., for purposes of
        redundancy detection) if they contain the same token stream. Basically
        this means that the source code must be identical except for
        whitespace::

          assert EvalFactor("a + b") == EvalFactor("a+b")
          assert EvalFactor("a + b") != EvalFactor("b + a")
        """

        # For parsed formulas, the code will already have been normalized by
        # the parser. But let's normalize anyway, so we can be sure of having
        # consistent semantics for __eq__ and __hash__.
        self.code = normalize_token_spacing(code)
        self.origin = origin

    def name(self):
        """Return the normalized source code, which doubles as the name."""
        return self.code

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self.code)

    def __eq__(self, other):
        return (isinstance(other, EvalFactor)
                and self.code == other.code)

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        return hash((EvalFactor, self.code))

    def memorize_passes_needed(self, state, eval_env):
        """Prepare `state` for memorization and evaluation of this factor.

        Fills `state` with the keys ``"transforms"``, ``"eval_env"``,
        ``"eval_code"``, ``"memorize_code"`` and ``"pass_bins"``.

        :arg state: A dict that is filled in here and handed back to the
          other memorize/eval methods.
        :arg eval_env: The :class:`EvalEnvironment` in which the code is
          to be evaluated.
        :returns: The number of memorization passes required.
        :raises PatsyError: If the code itself refers to one of the
          internally generated transform object names.
        """
        # 'state' is just an empty dict which we can do whatever we want with,
        # and that will be passed back to later memorize functions
        state["transforms"] = {}

        eval_env = eval_env.with_outer_namespace(_builtins_dict)
        env_namespace = eval_env.namespace
        # Keep only the variables the code actually mentions, so the stored
        # environment does not hold on to the whole caller namespace.
        subset_names = [name for name in ast_names(self.code)
                        if name in env_namespace]
        eval_env = eval_env.subset(subset_names)
        state["eval_env"] = eval_env

        # example code: == "2 * center(x)"
        # A one-element list serves as a counter the closure can mutate.
        i = [0]
        def new_name_maker(token):
            value = eval_env.namespace.get(token)
            if hasattr(value, "__patsy_stateful_transform__"):
                obj_name = "_patsy_stobj%s__%s__" % (i[0], token)
                i[0] += 1
                obj = value.__patsy_stateful_transform__()
                state["transforms"][obj_name] = obj
                return obj_name + ".transform"
            else:
                return token
        # example eval_code: == "2 * _patsy_stobj0__center__.transform(x)"
        eval_code = replace_bare_funcalls(self.code, new_name_maker)
        state["eval_code"] = eval_code
        # paranoia: verify that none of our new names appeared anywhere in the
        # original code. The tokens are walked explicitly so that the
        # offending token and its origin are available for the error; the
        # previous code referenced an undefined 'token' here and so died
        # with NameError instead of raising PatsyError.
        for (_, token, origin, props) in annotated_tokens(self.code):
            if props["bare_ref"] and token in state["transforms"]:
                raise PatsyError("names of this form are reserved for "
                                 "internal use (%s)" % (token,), origin)
        # Pull out all the '_patsy_stobj0__center__.transform(x)' pieces
        # to make '_patsy_stobj0__center__.memorize_chunk(x)' pieces
        state["memorize_code"] = {}
        for obj_name in state["transforms"]:
            transform_calls = capture_obj_method_calls(obj_name, eval_code)
            assert len(transform_calls) == 1
            transform_call = transform_calls[0]
            transform_call_name, transform_call_code = transform_call
            assert transform_call_name == obj_name + ".transform"
            assert transform_call_code.startswith(transform_call_name + "(")
            memorize_code = (obj_name
                             + ".memorize_chunk"
                             + transform_call_code[len(transform_call_name):])
            state["memorize_code"][obj_name] = memorize_code
        # Then sort the codes into bins, so that every item in bin number i
        # depends only on items in bin (i-1) or less. (By 'depends', we mean
        # that in something like:
        #   spline(center(x))
        # we have to first run:
        #    center.memorize_chunk(x)
        # then
        #    center.memorize_finish(x)
        # and only then can we run:
        #    spline.memorize_chunk(center.transform(x))
        # Since all of our objects have unique names, figuring out who
        # depends on who is pretty easy -- we just check whether the
        # memorization code for spline:
        #    spline.memorize_chunk(center.transform(x))
        # mentions the variable 'center' (which in the example, of course, it
        # does).
        pass_bins = []
        unsorted = set(state["transforms"])
        while unsorted:
            pass_bin = set()
            for obj_name in unsorted:
                other_objs = unsorted.difference([obj_name])
                memorize_code = state["memorize_code"][obj_name]
                if not has_bare_variable_reference(other_objs, memorize_code):
                    pass_bin.add(obj_name)
            # Every round must make progress, or the loop would never end.
            assert pass_bin
            unsorted.difference_update(pass_bin)
            pass_bins.append(pass_bin)
        state["pass_bins"] = pass_bins

        return len(pass_bins)

    def _eval(self, code, memorize_state, data):
        # Variables are looked up in 'data' first, then among the transform
        # objects, and finally in the stored evaluation environment.
        inner_namespace = VarLookupDict([data, memorize_state["transforms"]])
        return call_and_wrap_exc("Error evaluating factor",
                                 self,
                                 memorize_state["eval_env"].eval,
                                 code,
                                 inner_namespace=inner_namespace)

    def memorize_chunk(self, state, which_pass, data):
        """Run the memorize code of every transform in pass `which_pass` on `data`."""
        for obj_name in state["pass_bins"][which_pass]:
            self._eval(state["memorize_code"][obj_name],
                       state,
                       data)

    def memorize_finish(self, state, which_pass):
        """Call ``memorize_finish()`` on every transform in pass `which_pass`."""
        for obj_name in state["pass_bins"][which_pass]:
            state["transforms"][obj_name].memorize_finish()

    def eval(self, memorize_state, data):
        """Evaluate the rewritten factor code against `data` and return the result."""
        return self._eval(memorize_state["eval_code"],
                          memorize_state,
                          data)

    __getstate__ = no_pickling

569 

def test_EvalFactor_basics():
    """Check normalization, naming, equality/hash and origin of EvalFactor."""
    plain = EvalFactor("a+b")
    # Token spacing is normalized on construction.
    assert plain.code == "a + b"
    assert plain.name() == "a + b"
    with_origin = EvalFactor("a +b", origin="asdf")
    # The origin plays no part in equality or hashing.
    assert plain == with_origin
    assert hash(plain) == hash(with_origin)
    assert plain.origin is None
    assert with_origin.origin == "asdf"

    assert_no_pickling(plain)

581 

def test_EvalFactor_memorize_passes_needed():
    """Check every piece of state that memorize_passes_needed fills in.

    The locals ``foo``, ``bar`` and ``quux`` must keep these names: the
    factor code refers to them and the environment is captured from this
    function's own frame.
    """
    from patsy.state import stateful_transform
    foo = stateful_transform(lambda: "FOO-OBJ")
    bar = stateful_transform(lambda: "BAR-OBJ")
    quux = stateful_transform(lambda: "QUUX-OBJ")
    e = EvalFactor("foo(x) + bar(foo(y)) + quux(z, w)")

    state = {}
    eval_env = EvalEnvironment.capture(0)
    passes = e.memorize_passes_needed(state, eval_env)
    print(passes)
    print(state)
    # bar(foo(y)) depends on the inner foo, hence the second pass.
    assert passes == 2
    # The stored environment holds exactly the names that are mentioned in
    # the code AND defined in the captured environment...
    for name in ["foo", "bar", "quux"]:
        assert state["eval_env"].namespace[name] is locals()[name]
    # ...and neither undefined data variables nor unrelated locals.
    for name in ["w", "x", "y", "z", "e", "state"]:
        assert name not in state["eval_env"].namespace
    assert state["transforms"] == {"_patsy_stobj0__foo__": "FOO-OBJ",
                                   "_patsy_stobj1__bar__": "BAR-OBJ",
                                   "_patsy_stobj2__foo__": "FOO-OBJ",
                                   "_patsy_stobj3__quux__": "QUUX-OBJ"}
    assert (state["eval_code"]
            == "_patsy_stobj0__foo__.transform(x)"
               " + _patsy_stobj1__bar__.transform("
               "_patsy_stobj2__foo__.transform(y))"
               " + _patsy_stobj3__quux__.transform(z, w)")

    assert (state["memorize_code"]
            == {"_patsy_stobj0__foo__":
                    "_patsy_stobj0__foo__.memorize_chunk(x)",
                "_patsy_stobj1__bar__":
                    "_patsy_stobj1__bar__.memorize_chunk(_patsy_stobj2__foo__.transform(y))",
                "_patsy_stobj2__foo__":
                    "_patsy_stobj2__foo__.memorize_chunk(y)",
                "_patsy_stobj3__quux__":
                    "_patsy_stobj3__quux__.memorize_chunk(z, w)",
                })
    # First bin: transforms with no dependencies; second bin: bar, whose
    # memorize code mentions _patsy_stobj2__foo__.
    assert state["pass_bins"] == [set(["_patsy_stobj0__foo__",
                                       "_patsy_stobj2__foo__",
                                       "_patsy_stobj3__quux__"]),
                                  set(["_patsy_stobj1__bar__"])]

623 

class _MockTransform(object):
    """Test double for a stateful transform.

    Adds up all memorized data, then subtracts that sum from each datum.
    It also counts how often each memorize hook has been called.
    """
    def __init__(self):
        self._memorize_finish_called = 0
        self._memorize_chunk_called = 0
        self._sum = 0

    def memorize_chunk(self, data):
        self._memorize_chunk_called = self._memorize_chunk_called + 1
        import numpy as np
        chunk_total = np.sum(data)
        self._sum = self._sum + chunk_total

    def memorize_finish(self):
        self._memorize_finish_called = self._memorize_finish_called + 1

    def transform(self, data):
        shifted = data - self._sum
        return shifted

641 

def test_EvalFactor_end_to_end():
    """Drive an EvalFactor through both memorization passes and a final eval.

    The local ``foo`` must keep its name: the factor code refers to it and
    the environment is captured from this function's own frame.
    """
    from patsy.state import stateful_transform
    foo = stateful_transform(_MockTransform)
    e = EvalFactor("foo(x) + foo(foo(y))")
    state = {}
    eval_env = EvalEnvironment.capture(0)
    passes = e.memorize_passes_needed(state, eval_env)
    print(passes)
    print(state)
    assert passes == 2
    assert state["eval_env"].namespace["foo"] is foo
    for name in ["x", "y", "e", "state"]:
        assert name not in state["eval_env"].namespace
    import numpy as np
    # Pass 0, first chunk: only transforms 0 and 2 (no dependencies) run.
    e.memorize_chunk(state, 0,
                     {"x": np.array([1, 2]),
                      "y": np.array([10, 11])})
    assert state["transforms"]["_patsy_stobj0__foo__"]._memorize_chunk_called == 1
    assert state["transforms"]["_patsy_stobj2__foo__"]._memorize_chunk_called == 1
    # Pass 0, second chunk.
    e.memorize_chunk(state, 0, {"x": np.array([12, -10]),
                                "y": np.array([100, 3])})
    assert state["transforms"]["_patsy_stobj0__foo__"]._memorize_chunk_called == 2
    assert state["transforms"]["_patsy_stobj2__foo__"]._memorize_chunk_called == 2
    assert state["transforms"]["_patsy_stobj0__foo__"]._memorize_finish_called == 0
    assert state["transforms"]["_patsy_stobj2__foo__"]._memorize_finish_called == 0
    e.memorize_finish(state, 0)
    assert state["transforms"]["_patsy_stobj0__foo__"]._memorize_finish_called == 1
    assert state["transforms"]["_patsy_stobj2__foo__"]._memorize_finish_called == 1
    # Transform 1 wraps transform 2, so it has not been touched yet.
    assert state["transforms"]["_patsy_stobj1__foo__"]._memorize_chunk_called == 0
    assert state["transforms"]["_patsy_stobj1__foo__"]._memorize_finish_called == 0
    # Pass 1: the same two chunks again, now feeding transform 1.
    e.memorize_chunk(state, 1, {"x": np.array([1, 2]),
                                "y": np.array([10, 11])})
    e.memorize_chunk(state, 1, {"x": np.array([12, -10]),
                                "y": np.array([100, 3])})
    e.memorize_finish(state, 1)
    for transform in six.itervalues(state["transforms"]):
        assert transform._memorize_chunk_called == 2
        assert transform._memorize_finish_called == 1
    # sums:
    # 0: 1 + 2 + 12 + -10 == 5
    # 2: 10 + 11 + 100 + 3 == 124
    # 1: (10 - 124) + (11 - 124) + (100 - 124) + (3 - 124) == -372
    # results:
    # 0: -4, -3, 7, -15
    # 2: -114, -113, -24, -121
    # 1: 258, 259, 348, 251
    # 0 + 1: 254, 256, 355, 236
    assert np.all(e.eval(state,
                         {"x": np.array([1, 2, 12, -10]),
                          "y": np.array([10, 11, 100, 3])})
                  == [254, 256, 355, 236])

693 

def annotated_tokens(code):
    """Yield ``(token_type, token, origin, props)`` for each token of `code`.

    `props` is a dict with two entries:

    * ``"bare_ref"``: the token is a NAME that does not directly follow a
      ``"."`` token.
    * ``"bare_funcall"``: the token is a bare reference and the next token
      is ``"("``.
    """
    after_dot = False
    stream = PushbackAdapter(python_tokenize(code))
    for (kind, text, origin) in stream:
        is_bare_ref = (not after_dot and kind == tokenize.NAME)
        # Peek one token ahead, without consuming it, to spot a call.
        is_bare_funcall = (is_bare_ref
                           and stream.has_more() and stream.peek()[1] == "(")
        yield (kind, text, origin,
               {"bare_ref": is_bare_ref, "bare_funcall": is_bare_funcall})
        after_dot = (text == ".")

704 

def test_annotated_tokens():
    """Check the bare_ref / bare_funcall annotations on a small expression."""
    def flags(bare_ref, bare_funcall):
        return {"bare_ref": bare_ref, "bare_funcall": bare_funcall}

    # Origins are dropped; they are not under test here.
    observed = [(kind, text, props)
                for (kind, text, _origin, props)
                in annotated_tokens("a(b) + c.d")]
    expected = [
        (tokenize.NAME, "a", flags(True, True)),
        (tokenize.OP, "(", flags(False, False)),
        (tokenize.NAME, "b", flags(True, False)),
        (tokenize.OP, ")", flags(False, False)),
        (tokenize.OP, "+", flags(False, False)),
        (tokenize.NAME, "c", flags(True, False)),
        (tokenize.OP, ".", flags(False, False)),
        (tokenize.NAME, "d", flags(False, False)),
    ]
    assert observed == expected

    # This was a bug:
    assert len(list(annotated_tokens("x"))) == 1

723 

def has_bare_variable_reference(names, code):
    """Return True if `code` contains a bare reference to any of `names`.

    A bare reference is a token that annotated_tokens marks with
    ``"bare_ref"``.
    """
    return any(props["bare_ref"] and text in names
               for (_, text, _, props) in annotated_tokens(code))

729 

def replace_bare_funcalls(code, replacer):
    """Return `code` with every bare function-call name passed through `replacer`.

    :arg code: A string of Python source.
    :arg replacer: A callable taking the called name and returning the text
      to put in its place.
    """
    rewritten = []
    for (kind, text, _origin, props) in annotated_tokens(code):
        is_call_name = props["bare_ref"] and props["bare_funcall"]
        rewritten.append((kind, replacer(text) if is_call_name else text))
    return pretty_untokenize(rewritten)

737 

def test_replace_bare_funcalls():
    """Only bare call names may be replaced; attributes and plain names stay."""
    substitutions = {"a": "b", "foo": "_internal.foo.process"}
    def replacer1(token):
        return substitutions.get(token, token)
    def t1(code, expected):
        replaced = replace_bare_funcalls(code, replacer1)
        print("%r -> %r" % (code, replaced))
        print("(wanted %r)" % (expected,))
        assert replaced == expected
    cases = [
        ("foobar()", "foobar()"),
        ("a()", "b()"),
        ("foobar.a()", "foobar.a()"),
        ("foo()", "_internal.foo.process()"),
        ("a + 1", "a + 1"),
        ("b() + a() * x[foo(2 ** 3)]",
         "b() + b() * x[_internal.foo.process(2 ** 3)]"),
    ]
    for code, expected in cases:
        t1(code, expected)

753 

class _FuncallCapturer(object):
    """Collects the tokens of the next function call.

    ``func`` accumulates the tokens of the callee (everything before the
    opening parenthesis); ``tokens`` accumulates every ``(type, token)``
    pair up to and including the parenthesis that closes the call. Once
    ``done`` is set, further tokens are ignored.
    """
    def __init__(self, start_token_type, start_token):
        self.tokens = [(start_token_type, start_token)]
        self.func = [start_token]
        self.started = False
        self.done = False
        self.paren_depth = 0

    def add_token(self, token_type, token):
        if self.done:
            return
        self.tokens.append((token_type, token))
        # Track nesting across all three bracket kinds.
        if token in ("(", "{", "["):
            self.paren_depth += 1
        elif token in (")", "}", "]"):
            self.paren_depth -= 1
        assert self.paren_depth >= 0
        if not self.started:
            if token != "(":
                # Still inside the dotted callee name.
                assert token_type == tokenize.NAME or token == "."
                self.func.append(token)
            else:
                self.started = True
        # The call is complete once its opening parenthesis has been closed.
        self.done = self.started and self.paren_depth == 0

780 

# This is not a very general function -- it assumes that all references to the
# given object are of the form '<obj_name>.something(method call)'.
def capture_obj_method_calls(obj_name, code):
    """Find every method call on `obj_name` within `code`.

    :returns: A list of ``(callee, call_code)`` pairs, one per bare
      reference to `obj_name`, e.g. ``("foo.baz", "foo.baz(bar)")``.
    """
    active = []
    for (kind, text, _origin, props) in annotated_tokens(code):
        # Feed the token to the capturers started so far; finished ones
        # ignore it.
        for capturer in active:
            capturer.add_token(kind, text)
        if props["bare_ref"] and text == obj_name:
            active.append(_FuncallCapturer(kind, text))
    results = []
    for capturer in active:
        callee = "".join(capturer.func)
        results.append((callee, pretty_untokenize(capturer.tokens)))
    return results

792 

def test_capture_obj_method_calls():
    """Check single, multiple, nested and bracket-enclosed method calls."""
    two_calls = "a + foo.baz(bar) + b.c(d)"
    cases = [
        ("foo", two_calls, [("foo.baz", "foo.baz(bar)")]),
        ("b", two_calls, [("b.c", "b.c(d)")]),
        # Nested calls on the same object are reported outermost first.
        ("foo", "foo.bar(foo.baz(quux))",
         [("foo.bar", "foo.bar(foo.baz(quux))"),
          ("foo.baz", "foo.baz(quux)")]),
        ("bar", "foo[bar.baz(x(z[asdf])) ** 2]",
         [("bar.baz", "bar.baz(x(z[asdf]))")]),
    ]
    for obj_name, code, expected in cases:
        assert capture_obj_method_calls(obj_name, code) == expected