Coverage for /opt/homebrew/lib/python3.11/site-packages/_pytest/assertion/rewrite.py: 33%

624 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-05-04 13:14 +0700

1"""Rewrite assertion AST to produce nice error messages.""" 

2import ast 

3import errno 

4import functools 

5import importlib.abc 

6import importlib.machinery 

7import importlib.util 

8import io 

9import itertools 

10import marshal 

11import os 

12import struct 

13import sys 

14import tokenize 

15import types 

16from pathlib import Path 

17from pathlib import PurePath 

18from typing import Callable 

19from typing import Dict 

20from typing import IO 

21from typing import Iterable 

22from typing import Iterator 

23from typing import List 

24from typing import Optional 

25from typing import Sequence 

26from typing import Set 

27from typing import Tuple 

28from typing import TYPE_CHECKING 

29from typing import Union 

30 

31from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE 

32from _pytest._io.saferepr import saferepr 

33from _pytest._version import version 

34from _pytest.assertion import util 

35from _pytest.assertion.util import ( # noqa: F401 

36 format_explanation as _format_explanation, 

37) 

38from _pytest.config import Config 

39from _pytest.main import Session 

40from _pytest.pathlib import absolutepath 

41from _pytest.pathlib import fnmatch_ex 

42from _pytest.stash import StashKey 

43 

44if TYPE_CHECKING: 

45 from _pytest.assertion import AssertionState 

46 

47 

# Stash key under which the assertion state object is stored on the config.
assertstate_key = StashKey["AssertionState"]()


# pytest caches rewritten pycs in pycache dirs.
# The tag combines the interpreter cache tag with the pytest version so that
# caches written by a different interpreter or pytest release are not reused.
PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
# ".pyc" normally, ".pyo" when running with optimizations (__debug__ is False).
PYC_EXT = ".py" + ("c" if __debug__ else "o")
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT

55 

56 

class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    """PEP302/PEP451 import hook which rewrites asserts.

    The hook acts as both finder and loader: ``find_spec`` decides whether a
    module should be rewritten and, if so, returns a spec whose loader is this
    object; ``exec_module`` then loads (or produces and caches) the rewritten
    code object and executes it in the module namespace.
    """

    def __init__(self, config: Config) -> None:
        self.config = config
        try:
            self.fnpats = config.getini("python_files")
        except ValueError:
            # Fall back to the default test-file patterns.
            self.fnpats = ["test_*.py", "*_test.py"]
        self.session: Optional[Session] = None
        # Module name -> source path of every module this hook has executed.
        self._rewritten_names: Dict[str, Path] = {}
        # Names explicitly registered through mark_rewrite().
        self._must_rewrite: Set[str] = set()
        # Flag to guard against trying to rewrite a pyc file while we are
        # already writing another pyc file, which might result in infinite
        # recursion (#3506).
        self._writing_pyc = False
        self._basenames_to_check_rewrite = {"conftest"}
        self._marked_for_rewrite_cache: Dict[str, bool] = {}
        self._session_paths_checked = False

    def set_session(self, session: Optional[Session]) -> None:
        """Attach (or detach) the session and force its paths to be re-read."""
        self.session = session
        self._session_paths_checked = False

    # Indirection so we can mock calls to find_spec originated from the hook
    # during testing.
    _find_spec = importlib.machinery.PathFinder.find_spec

    def find_spec(
        self,
        name: str,
        path: Optional[Sequence[Union[str, bytes]]] = None,
        target: Optional[types.ModuleType] = None,
    ) -> Optional[importlib.machinery.ModuleSpec]:
        """Return a spec loaded by this hook, or None to defer to other finders."""
        if self._writing_pyc:
            return None
        state = self.config.stash[assertstate_key]
        if self._early_rewrite_bailout(name, state):
            return None
        state.trace("find_module called for: %s" % name)

        # Type ignored because mypy is confused about the `self` binding here.
        spec = self._find_spec(name, path)  # type: ignore

        # The import machinery could not find a file to import.
        if spec is None:
            return None
        # This is a namespace package (without `__init__.py`); there is
        # nothing to rewrite there.
        if spec.origin is None:
            return None
        # We can only rewrite source files.
        if not isinstance(spec.loader, importlib.machinery.SourceFileLoader):
            return None
        # If the file doesn't exist, we can't rewrite it.
        if not os.path.exists(spec.origin):
            return None

        fn = spec.origin
        if not self._should_rewrite(name, fn, state):
            return None

        return importlib.util.spec_from_file_location(
            name,
            fn,
            loader=self,
            submodule_search_locations=spec.submodule_search_locations,
        )

    def create_module(
        self, spec: importlib.machinery.ModuleSpec
    ) -> Optional[types.ModuleType]:
        return None  # default behaviour is fine

    def exec_module(self, module: types.ModuleType) -> None:
        """Execute *module* using its rewritten (possibly cached) code object."""
        assert module.__spec__ is not None
        assert module.__spec__.origin is not None
        fn = Path(module.__spec__.origin)
        state = self.config.stash[assertstate_key]

        self._rewritten_names[module.__name__] = fn

        # The requested module looks like a test file, so rewrite it. This is
        # the most magical part of the process: load the source, rewrite the
        # asserts, and load the rewritten source. We also cache the rewritten
        # module code in a special pyc. We must be aware of the possibility of
        # concurrent pytest processes rewriting and loading pycs. To avoid
        # tricky race conditions, we maintain the following invariant: The
        # cached pyc is always a complete, valid pyc. Operations on it must be
        # atomic. POSIX's atomic rename comes in handy.
        write = not sys.dont_write_bytecode
        cache_dir = get_cache_dir(fn)
        if write and not try_makedirs(cache_dir):
            write = False
            state.trace(f"read only directory: {cache_dir}")

        pyc = cache_dir / (fn.name[:-3] + PYC_TAIL)
        # Notice that even if we're in a read-only directory, a cached pyc is
        # still looked up. This may not be optimal...
        co = _read_pyc(fn, pyc, state.trace)
        if co is not None:
            state.trace(f"found cached rewritten pyc for {fn}")
        else:
            state.trace(f"rewriting {fn!r}")
            source_stat, co = _rewrite_test(fn, self.config)
            if write:
                self._writing_pyc = True
                try:
                    _write_pyc(state, co, source_stat, pyc)
                finally:
                    self._writing_pyc = False
        exec(co, module.__dict__)

    def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool:
        """A fast way to get out of rewriting modules.

        Profiling has shown that the call to PathFinder.find_spec (inside of
        the find_spec from this class) is a major slowdown, so, this method
        tries to filter what we're sure won't be rewritten before getting to
        it.
        """
        if self.session is not None and not self._session_paths_checked:
            self._session_paths_checked = True
            for initial_path in self.session._initialpaths:
                # Take the last component of something like
                # c:/projects/my_project/path.py -> 'path.py' and register
                # 'path' as a basename to be checked.
                last_component = str(initial_path).rpartition(os.path.sep)[2]
                basename = os.path.splitext(last_component)[0]
                self._basenames_to_check_rewrite.add(basename)

        # Note: conftest already by default in _basenames_to_check_rewrite.
        parts = name.split(".")
        if parts[-1] in self._basenames_to_check_rewrite:
            return False

        # For matching the name it must be as if it was a filename.
        path = PurePath(*parts).with_suffix(".py")

        for pat in self.fnpats:
            # If the pattern contains subdirectories ("tests/**.py" for
            # example) we can't bail out based on the name alone because we
            # need to match against the full path.
            if os.path.dirname(pat):
                return False
            if fnmatch_ex(pat, path):
                return False

        if self._is_marked_for_rewrite(name, state):
            return False

        state.trace(f"early skip of rewriting module: {name}")
        return True

    def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool:
        """Decide, based on the resolved file name, whether to rewrite a module."""
        # Always rewrite conftest files.
        if os.path.basename(fn) == "conftest.py":
            state.trace(f"rewriting conftest file: {fn!r}")
            return True

        if self.session is not None and self.session.isinitpath(absolutepath(fn)):
            state.trace(f"matched test file (was specified on cmdline): {fn!r}")
            return True

        # Modules not passed explicitly on the command line are only
        # rewritten if they match the naming convention for test files.
        fn_path = PurePath(fn)
        if any(fnmatch_ex(pat, fn_path) for pat in self.fnpats):
            state.trace(f"matched test file {fn!r}")
            return True

        return self._is_marked_for_rewrite(name, state)

    def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool:
        """Return whether *name* (or a parent package) was passed to mark_rewrite()."""
        cached = self._marked_for_rewrite_cache.get(name)
        if cached is not None:
            return cached

        for marked in self._must_rewrite:
            if name == marked or name.startswith(marked + "."):
                state.trace(f"matched marked file {name!r} (from {marked!r})")
                self._marked_for_rewrite_cache[name] = True
                return True

        self._marked_for_rewrite_cache[name] = False
        return False

    def mark_rewrite(self, *names: str) -> None:
        """Mark import names as needing to be rewritten.

        The named module or package as well as any nested modules will
        be rewritten on import.
        """
        already_imported = (
            set(names).intersection(sys.modules).difference(self._rewritten_names)
        )
        for name in already_imported:
            mod = sys.modules[name]
            # No warning for modules that opted out of rewriting or that were
            # loaded by a hook of this type.
            if AssertionRewriter.is_rewrite_disabled(
                mod.__doc__ or ""
            ) or isinstance(mod.__loader__, type(self)):
                continue
            self._warn_already_imported(name)
        self._must_rewrite.update(names)
        # Cached negative answers may no longer hold after new names are added.
        self._marked_for_rewrite_cache.clear()

    def _warn_already_imported(self, name: str) -> None:
        from _pytest.warning_types import PytestAssertRewriteWarning

        warning = PytestAssertRewriteWarning(
            "Module already imported so cannot be rewritten: %s" % name
        )
        self.config.issue_config_time_warning(warning, stacklevel=5)

    def get_data(self, pathname: Union[str, bytes]) -> bytes:
        """Optional PEP302 get_data API."""
        with open(pathname, "rb") as f:
            return f.read()

    if sys.version_info >= (3, 10):

        if sys.version_info >= (3, 12):
            from importlib.resources.abc import TraversableResources
        else:
            from importlib.abc import TraversableResources

        def get_resource_reader(self, name: str) -> TraversableResources:  # type: ignore
            if sys.version_info < (3, 11):
                from importlib.readers import FileReader
            else:
                from importlib.resources.readers import FileReader

            return FileReader(  # type:ignore[no-any-return]
                types.SimpleNamespace(path=self._rewritten_names[name])
            )

292 

293 

294def _write_pyc_fp( 

295 fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType 

296) -> None: 

297 # Technically, we don't have to have the same pyc format as 

298 # (C)Python, since these "pycs" should never be seen by builtin 

299 # import. However, there's little reason to deviate. 

300 fp.write(importlib.util.MAGIC_NUMBER) 

301 # https://www.python.org/dev/peps/pep-0552/ 

302 flags = b"\x00\x00\x00\x00" 

303 fp.write(flags) 

304 # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903) 

305 mtime = int(source_stat.st_mtime) & 0xFFFFFFFF 

306 size = source_stat.st_size & 0xFFFFFFFF 

307 # "<LL" stands for 2 unsigned longs, little-endian. 

308 fp.write(struct.pack("<LL", mtime, size)) 

309 fp.write(marshal.dumps(co)) 

310 

311 

312def _write_pyc( 

313 state: "AssertionState", 

314 co: types.CodeType, 

315 source_stat: os.stat_result, 

316 pyc: Path, 

317) -> bool: 

318 proc_pyc = f"{pyc}.{os.getpid()}" 

319 try: 

320 with open(proc_pyc, "wb") as fp: 

321 _write_pyc_fp(fp, source_stat, co) 

322 except OSError as e: 

323 state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}") 

324 return False 

325 

326 try: 

327 os.replace(proc_pyc, pyc) 

328 except OSError as e: 

329 state.trace(f"error writing pyc file at {pyc}: {e}") 

330 # we ignore any failure to write the cache file 

331 # there are many reasons, permission-denied, pycache dir being a 

332 # file etc. 

333 return False 

334 return True 

335 

336 

def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
    """Read and rewrite *fn* and return the code object.

    The stat result of the source file (taken before it is read) is returned
    alongside the compiled, assertion-rewritten code object.
    """
    source_stat = os.stat(fn)
    source_bytes = fn.read_bytes()
    filename = str(fn)
    module_ast = ast.parse(source_bytes, filename=filename)
    # Rewrites the tree in place.
    rewrite_asserts(module_ast, source_bytes, filename, config)
    code = compile(module_ast, filename, "exec", dont_inherit=True)
    return source_stat, code

346 

347 

348def _read_pyc( 

349 source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None 

350) -> Optional[types.CodeType]: 

351 """Possibly read a pytest pyc containing rewritten code. 

352 

353 Return rewritten code if successful or None if not. 

354 """ 

355 try: 

356 fp = open(pyc, "rb") 

357 except OSError: 

358 return None 

359 with fp: 

360 try: 

361 stat_result = os.stat(source) 

362 mtime = int(stat_result.st_mtime) 

363 size = stat_result.st_size 

364 data = fp.read(16) 

365 except OSError as e: 

366 trace(f"_read_pyc({source}): OSError {e}") 

367 return None 

368 # Check for invalid or out of date pyc file. 

369 if len(data) != (16): 

370 trace("_read_pyc(%s): invalid pyc (too short)" % source) 

371 return None 

372 if data[:4] != importlib.util.MAGIC_NUMBER: 

373 trace("_read_pyc(%s): invalid pyc (bad magic number)" % source) 

374 return None 

375 if data[4:8] != b"\x00\x00\x00\x00": 

376 trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source) 

377 return None 

378 mtime_data = data[8:12] 

379 if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF: 

380 trace("_read_pyc(%s): out of date" % source) 

381 return None 

382 size_data = data[12:16] 

383 if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF: 

384 trace("_read_pyc(%s): invalid pyc (incorrect size)" % source) 

385 return None 

386 try: 

387 co = marshal.load(fp) 

388 except Exception as e: 

389 trace(f"_read_pyc({source}): marshal.load error {e}") 

390 return None 

391 if not isinstance(co, types.CodeType): 

392 trace("_read_pyc(%s): not a code object" % source) 

393 return None 

394 return co 

395 

396 

def rewrite_asserts(
    mod: ast.Module,
    source: bytes,
    module_path: Optional[str] = None,
    config: Optional[Config] = None,
) -> None:
    """Rewrite the assert statements in mod.

    The module AST is modified in place by an AssertionRewriter built from
    *module_path*, *config* and the module's *source* bytes.
    """
    rewriter = AssertionRewriter(module_path, config, source)
    rewriter.run(mod)

405 

406 

def _saferepr(obj: object) -> str:
    r"""Get a safe repr of an object for assertion error messages.

    The assertion formatting (util.format_explanation()) requires
    newlines to be escaped since they are a special character for it.
    Normally assertion.util.format_explanation() does this but for a
    custom repr it is possible to contain one of the special escape
    sequences, especially '\n{' and '\n}' are likely to be present in
    JSON reprs.
    """
    size_limit = _get_maxsize_for_saferepr(util._config)
    text = saferepr(obj, maxsize=size_limit)
    # Turn real newlines into the two-character sequence backslash + "n".
    return text.replace("\n", "\\n")

419 

420 

def _get_maxsize_for_saferepr(config: Optional[Config]) -> Optional[int]:
    """Get `maxsize` configuration for saferepr based on the given config object.

    Without a config the verbosity is taken to be 0. Verbosity 2 or higher
    removes the limit (None), verbosity 1 uses ten times the default limit,
    anything lower uses the default limit.
    """
    if config is None:
        verbosity = 0
    else:
        verbosity = config.getoption("verbose")

    if verbosity >= 2:
        return None
    if verbosity >= 1:
        return DEFAULT_REPR_MAX_SIZE * 10
    return DEFAULT_REPR_MAX_SIZE

429 

430 

431def _format_assertmsg(obj: object) -> str: 

432 r"""Format the custom assertion message given. 

433 

434 For strings this simply replaces newlines with '\n~' so that 

435 util.format_explanation() will preserve them instead of escaping 

436 newlines. For other objects saferepr() is used first. 

437 """ 

438 # reprlib appears to have a bug which means that if a string 

439 # contains a newline it gets escaped, however if an object has a 

440 # .__repr__() which contains newlines it does not get escaped. 

441 # However in either case we want to preserve the newline. 

442 replaces = [("\n", "\n~"), ("%", "%%")] 

443 if not isinstance(obj, str): 

444 obj = saferepr(obj) 

445 replaces.append(("\\n", "\n~")) 

446 

447 for r1, r2 in replaces: 

448 obj = obj.replace(r1, r2) 

449 

450 return obj 

451 

452 

453def _should_repr_global_name(obj: object) -> bool: 

454 if callable(obj): 

455 return False 

456 

457 try: 

458 return not hasattr(obj, "__name__") 

459 except Exception: 

460 return True 

461 

462 

463def _format_boolop(explanations: Iterable[str], is_or: bool) -> str: 

464 explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")" 

465 return explanation.replace("%", "%%") 

466 

467 

def _call_reprcompare(
    ops: Sequence[str],
    results: Sequence[bool],
    expls: Sequence[str],
    each_obj: Sequence[object],
) -> str:
    """Return the explanation for the first failing comparison in a chain.

    The comparisons are scanned in order until one whose result is falsy (or
    whose truth value cannot be determined) is found; if none is, the last
    one is used. When a custom ``util._reprcompare`` callback is installed
    and returns a non-None explanation for that comparison, it takes
    precedence over the default explanation.
    """
    for i, (_op, res, expl) in enumerate(zip(ops, results, expls)):
        try:
            failed = not res
        except Exception:
            # Evaluating the truth value of the result raised; treat this
            # comparison as the one to report.
            failed = True
        if failed:
            break

    reprcompare = util._reprcompare
    if reprcompare is not None:
        custom = reprcompare(ops[i], each_obj[i], each_obj[i + 1])
        if custom is not None:
            return custom
    return expl

486 

487 

def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None:
    """Forward a passing assertion to ``util._assertion_pass`` when one is set."""
    callback = util._assertion_pass
    if callback is None:
        return
    callback(lineno, orig, expl)

491 

492 

def _check_if_assertion_pass_impl() -> bool:
    """Check if any plugins implement the pytest_assertion_pass hook
    in order not to generate explanation unnecessarily (might be expensive)."""
    return bool(util._assertion_pass)

497 

498 

# %-format templates used to render unary operators in explanations.
UNARY_MAP = {
    ast.Not: "not %s",
    ast.Invert: "~%s",
    ast.USub: "-%s",
    ast.UAdd: "+%s",
}

# Source-level symbol for each binary and comparison operator node type.
BINOP_MAP = {
    # Bitwise and shift operators.
    ast.BitOr: "|",
    ast.BitXor: "^",
    ast.BitAnd: "&",
    ast.LShift: "<<",
    ast.RShift: ">>",
    # Arithmetic operators.
    ast.Add: "+",
    ast.Sub: "-",
    ast.Mult: "*",
    ast.Div: "/",
    ast.FloorDiv: "//",
    ast.Mod: "%%",  # escaped for string formatting
    # Comparison operators.
    ast.Eq: "==",
    ast.NotEq: "!=",
    ast.Lt: "<",
    ast.LtE: "<=",
    ast.Gt: ">",
    ast.GtE: ">=",
    ast.Pow: "**",
    # Identity and membership tests.
    ast.Is: "is",
    ast.IsNot: "is not",
    ast.In: "in",
    ast.NotIn: "not in",
    ast.MatMult: "@",
}

526 

527 

def traverse_node(node: ast.AST) -> Iterator[ast.AST]:
    """Yield *node* followed by all of its descendants, depth-first (pre-order)."""
    pending = [node]
    while pending:
        current = pending.pop()
        yield current
        # Push children reversed so the first child is popped (visited) first.
        pending.extend(reversed(list(ast.iter_child_nodes(current))))

533 

534 

535@functools.lru_cache(maxsize=1) 

536def _get_assertion_exprs(src: bytes) -> Dict[int, str]: 

537 """Return a mapping from {lineno: "assertion test expression"}.""" 

538 ret: Dict[int, str] = {} 

539 

540 depth = 0 

541 lines: List[str] = [] 

542 assert_lineno: Optional[int] = None 

543 seen_lines: Set[int] = set() 

544 

545 def _write_and_reset() -> None: 

546 nonlocal depth, lines, assert_lineno, seen_lines 

547 assert assert_lineno is not None 

548 ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\") 

549 depth = 0 

550 lines = [] 

551 assert_lineno = None 

552 seen_lines = set() 

553 

554 tokens = tokenize.tokenize(io.BytesIO(src).readline) 

555 for tp, source, (lineno, offset), _, line in tokens: 

556 if tp == tokenize.NAME and source == "assert": 

557 assert_lineno = lineno 

558 elif assert_lineno is not None: 

559 # keep track of depth for the assert-message `,` lookup 

560 if tp == tokenize.OP and source in "([{": 

561 depth += 1 

562 elif tp == tokenize.OP and source in ")]}": 

563 depth -= 1 

564 

565 if not lines: 

566 lines.append(line[offset:]) 

567 seen_lines.add(lineno) 

568 # a non-nested comma separates the expression from the message 

569 elif depth == 0 and tp == tokenize.OP and source == ",": 

570 # one line assert with message 

571 if lineno in seen_lines and len(lines) == 1: 

572 offset_in_trimmed = offset + len(lines[-1]) - len(line) 

573 lines[-1] = lines[-1][:offset_in_trimmed] 

574 # multi-line assert with message 

575 elif lineno in seen_lines: 

576 lines[-1] = lines[-1][:offset] 

577 # multi line assert with escapd newline before message 

578 else: 

579 lines.append(line[:offset]) 

580 _write_and_reset() 

581 elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}: 

582 _write_and_reset() 

583 elif lines and lineno not in seen_lines: 

584 lines.append(line) 

585 seen_lines.add(lineno) 

586 

587 return ret 

588 

589 

590class AssertionRewriter(ast.NodeVisitor): 

591 """Assertion rewriting implementation. 

592 

593 The main entrypoint is to call .run() with an ast.Module instance, 

594 this will then find all the assert statements and rewrite them to 

595 provide intermediate values and a detailed assertion error. See 

596 http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html 

597 for an overview of how this works. 

598 

599 The entry point here is .run() which will iterate over all the 

600 statements in an ast.Module and for each ast.Assert statement it 

601 finds call .visit() with it. Then .visit_Assert() takes over and 

602 is responsible for creating new ast statements to replace the 

603 original assert statement: it rewrites the test of an assertion 

604 to provide intermediate values and replace it with an if statement 

605 which raises an assertion error with a detailed explanation in 

606 case the expression is false and calls pytest_assertion_pass hook 

607 if expression is true. 

608 

609 For this .visit_Assert() uses the visitor pattern to visit all the 

610 AST nodes of the ast.Assert.test field, each visit call returning 

611 an AST node and the corresponding explanation string. During this 

612 state is kept in several instance attributes: 

613 

614 :statements: All the AST statements which will replace the assert 

615 statement. 

616 

617 :variables: This is populated by .variable() with each variable 

618 used by the statements so that they can all be set to None at 

619 the end of the statements. 

620 

621 :variable_counter: Counter to create new unique variables needed 

622 by statements. Variables are created using .variable() and 

623 have the form of "@py_assert0". 

624 

625 :expl_stmts: The AST statements which will be executed to get 

626 data from the assertion. This is the code which will construct 

627 the detailed assertion message that is used in the AssertionError 

628 or for the pytest_assertion_pass hook. 

629 

630 :explanation_specifiers: A dict filled by .explanation_param() 

631 with %-formatting placeholders and their corresponding 

632 expressions to use in the building of an assertion message. 

633 This is used by .pop_format_context() to build a message. 

634 

635 :stack: A stack of the explanation_specifiers dicts maintained by 

636 .push_format_context() and .pop_format_context() which allows 

637 to build another %-formatted string while already building one. 

638 

639 This state is reset on every new assert statement visited and used 

640 by the other visitors. 

641 """ 

642 

643 def __init__( 

644 self, module_path: Optional[str], config: Optional[Config], source: bytes 

645 ) -> None: 

646 super().__init__() 

647 self.module_path = module_path 

648 self.config = config 

649 if config is not None: 

650 self.enable_assertion_pass_hook = config.getini( 

651 "enable_assertion_pass_hook" 

652 ) 

653 else: 

654 self.enable_assertion_pass_hook = False 

655 self.source = source 

656 

657 def run(self, mod: ast.Module) -> None: 

658 """Find all assert statements in *mod* and rewrite them.""" 

659 if not mod.body: 

660 # Nothing to do. 

661 return 

662 

663 # We'll insert some special imports at the top of the module, but after any 

664 # docstrings and __future__ imports, so first figure out where that is. 

665 doc = getattr(mod, "docstring", None) 

666 expect_docstring = doc is None 

667 if doc is not None and self.is_rewrite_disabled(doc): 

668 return 

669 pos = 0 

670 lineno = 1 

671 for item in mod.body: 

672 if ( 

673 expect_docstring 

674 and isinstance(item, ast.Expr) 

675 and isinstance(item.value, ast.Str) 

676 ): 

677 doc = item.value.s 

678 if self.is_rewrite_disabled(doc): 

679 return 

680 expect_docstring = False 

681 elif ( 

682 isinstance(item, ast.ImportFrom) 

683 and item.level == 0 

684 and item.module == "__future__" 

685 ): 

686 pass 

687 else: 

688 break 

689 pos += 1 

690 # Special case: for a decorated function, set the lineno to that of the 

691 # first decorator, not the `def`. Issue #4984. 

692 if isinstance(item, ast.FunctionDef) and item.decorator_list: 

693 lineno = item.decorator_list[0].lineno 

694 else: 

695 lineno = item.lineno 

696 # Now actually insert the special imports. 

697 if sys.version_info >= (3, 10): 

698 aliases = [ 

699 ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0), 

700 ast.alias( 

701 "_pytest.assertion.rewrite", 

702 "@pytest_ar", 

703 lineno=lineno, 

704 col_offset=0, 

705 ), 

706 ] 

707 else: 

708 aliases = [ 

709 ast.alias("builtins", "@py_builtins"), 

710 ast.alias("_pytest.assertion.rewrite", "@pytest_ar"), 

711 ] 

712 imports = [ 

713 ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases 

714 ] 

715 mod.body[pos:pos] = imports 

716 

717 # Collect asserts. 

718 nodes: List[ast.AST] = [mod] 

719 while nodes: 

720 node = nodes.pop() 

721 for name, field in ast.iter_fields(node): 

722 if isinstance(field, list): 

723 new: List[ast.AST] = [] 

724 for i, child in enumerate(field): 

725 if isinstance(child, ast.Assert): 

726 # Transform assert. 

727 new.extend(self.visit(child)) 

728 else: 

729 new.append(child) 

730 if isinstance(child, ast.AST): 

731 nodes.append(child) 

732 setattr(node, name, new) 

733 elif ( 

734 isinstance(field, ast.AST) 

735 # Don't recurse into expressions as they can't contain 

736 # asserts. 

737 and not isinstance(field, ast.expr) 

738 ): 

739 nodes.append(field) 

740 

741 @staticmethod 

742 def is_rewrite_disabled(docstring: str) -> bool: 

743 return "PYTEST_DONT_REWRITE" in docstring 

744 

745 def variable(self) -> str: 

746 """Get a new variable.""" 

747 # Use a character invalid in python identifiers to avoid clashing. 

748 name = "@py_assert" + str(next(self.variable_counter)) 

749 self.variables.append(name) 

750 return name 

751 

752 def assign(self, expr: ast.expr) -> ast.Name: 

753 """Give *expr* a name.""" 

754 name = self.variable() 

755 self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr)) 

756 return ast.Name(name, ast.Load()) 

757 

758 def display(self, expr: ast.expr) -> ast.expr: 

759 """Call saferepr on the expression.""" 

760 return self.helper("_saferepr", expr) 

761 

762 def helper(self, name: str, *args: ast.expr) -> ast.expr: 

763 """Call a helper in this module.""" 

764 py_name = ast.Name("@pytest_ar", ast.Load()) 

765 attr = ast.Attribute(py_name, name, ast.Load()) 

766 return ast.Call(attr, list(args), []) 

767 

768 def builtin(self, name: str) -> ast.Attribute: 

769 """Return the builtin called *name*.""" 

770 builtin_name = ast.Name("@py_builtins", ast.Load()) 

771 return ast.Attribute(builtin_name, name, ast.Load()) 

772 

773 def explanation_param(self, expr: ast.expr) -> str: 

774 """Return a new named %-formatting placeholder for expr. 

775 

776 This creates a %-formatting placeholder for expr in the 

777 current formatting context, e.g. ``%(py0)s``. The placeholder 

778 and expr are placed in the current format context so that it 

779 can be used on the next call to .pop_format_context(). 

780 """ 

781 specifier = "py" + str(next(self.variable_counter)) 

782 self.explanation_specifiers[specifier] = expr 

783 return "%(" + specifier + ")s" 

784 

785 def push_format_context(self) -> None: 

786 """Create a new formatting context. 

787 

788 The format context is used for when an explanation wants to 

789 have a variable value formatted in the assertion message. In 

790 this case the value required can be added using 

791 .explanation_param(). Finally .pop_format_context() is used 

792 to format a string of %-formatted values as added by 

793 .explanation_param(). 

794 """ 

795 self.explanation_specifiers: Dict[str, ast.expr] = {} 

796 self.stack.append(self.explanation_specifiers) 

797 

798 def pop_format_context(self, expl_expr: ast.expr) -> ast.Name: 

799 """Format the %-formatted string with current format context. 

800 

801 The expl_expr should be an str ast.expr instance constructed from 

802 the %-placeholders created by .explanation_param(). This will 

803 add the required code to format said string to .expl_stmts and 

804 return the ast.Name instance of the formatted string. 

805 """ 

806 current = self.stack.pop() 

807 if self.stack: 

808 self.explanation_specifiers = self.stack[-1] 

809 keys = [ast.Str(key) for key in current.keys()] 

810 format_dict = ast.Dict(keys, list(current.values())) 

811 form = ast.BinOp(expl_expr, ast.Mod(), format_dict) 

812 name = "@py_format" + str(next(self.variable_counter)) 

813 if self.enable_assertion_pass_hook: 

814 self.format_variables.append(name) 

815 self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form)) 

816 return ast.Name(name, ast.Load()) 

817 

818 def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]: 

819 """Handle expressions we don't have custom code for.""" 

820 assert isinstance(node, ast.expr) 

821 res = self.assign(node) 

822 return res, self.explanation_param(self.display(res)) 

823 

def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
    """Return the AST statements to replace the ast.Assert instance.

    This rewrites the test of an assertion to provide
    intermediate values and replace it with an if statement which
    raises an assertion error with a detailed explanation in case
    the expression is false.

    Literal nodes are created with ``ast.Constant``; the older
    ``ast.Str``/``ast.Num``/``ast.NameConstant`` aliases are deprecated
    since Python 3.8 and removed in newer Python releases.

    :param assert_: The ``assert`` statement node being rewritten.
    :returns: The list of statements accumulated in ``self.statements``.
    """
    # `assert (cond, "msg")` has a non-empty tuple as its test, which is
    # always truthy, so warn the user at the assert's own location.
    if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
        from _pytest.warning_types import PytestAssertRewriteWarning
        import warnings

        # TODO: This assert should not be needed.
        assert self.module_path is not None
        warnings.warn_explicit(
            PytestAssertRewriteWarning(
                "assertion is always true, perhaps remove parentheses?"
            ),
            category=None,
            filename=self.module_path,
            lineno=assert_.lineno,
        )

    # Reset the per-assert state used by the other visit_* methods.
    self.statements: List[ast.stmt] = []
    self.variables: List[str] = []
    self.variable_counter = itertools.count()

    if self.enable_assertion_pass_hook:
        self.format_variables: List[str] = []

    self.stack: List[Dict[str, ast.expr]] = []
    self.expl_stmts: List[ast.stmt] = []
    self.push_format_context()
    # Rewrite assert into a bunch of statements.
    top_condition, explanation = self.visit(assert_.test)

    negation = ast.UnaryOp(ast.Not(), top_condition)

    if self.enable_assertion_pass_hook:  # Experimental pytest_assertion_pass hook
        msg = self.pop_format_context(ast.Constant(explanation))

        # Failed
        if assert_.msg:
            assertmsg = self.helper("_format_assertmsg", assert_.msg)
            gluestr = "\n>assert "
        else:
            assertmsg = ast.Constant("")
            gluestr = "assert "
        err_explanation = ast.BinOp(ast.Constant(gluestr), ast.Add(), msg)
        err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
        err_name = ast.Name("AssertionError", ast.Load())
        fmt = self.helper("_format_explanation", err_msg)
        exc = ast.Call(err_name, [fmt], [])
        raise_ = ast.Raise(exc, None)
        # The failure branch runs the explanation statements, then raises.
        statements_fail = [*self.expl_stmts, raise_]

        # Passed
        fmt_pass = self.helper("_format_explanation", msg)
        orig = _get_assertion_exprs(self.source)[assert_.lineno]
        hook_call_pass = ast.Expr(
            self.helper(
                "_call_assertion_pass",
                ast.Constant(assert_.lineno),
                ast.Constant(orig),
                fmt_pass,
            )
        )
        # If any hooks implement assert_pass hook
        hook_impl_test = ast.If(
            self.helper("_check_if_assertion_pass_impl"),
            self.expl_stmts + [hook_call_pass],
            [],
        )
        statements_pass = [hook_impl_test]

        # Test for assertion condition
        main_test = ast.If(negation, statements_fail, statements_pass)
        self.statements.append(main_test)
        if self.format_variables:
            variables = [
                ast.Name(name, ast.Store()) for name in self.format_variables
            ]
            clear_format = ast.Assign(variables, ast.Constant(None))
            self.statements.append(clear_format)

    else:  # Original assertion rewriting
        # Create failure message.
        # `body` aliases self.expl_stmts on purpose: the `if` node is
        # appended first and the raise is added to the same list below.
        body = self.expl_stmts
        self.statements.append(ast.If(negation, body, []))
        if assert_.msg:
            assertmsg = self.helper("_format_assertmsg", assert_.msg)
            explanation = "\n>assert " + explanation
        else:
            assertmsg = ast.Constant("")
            explanation = "assert " + explanation
        template = ast.BinOp(assertmsg, ast.Add(), ast.Constant(explanation))
        msg = self.pop_format_context(template)
        fmt = self.helper("_format_explanation", msg)
        err_name = ast.Name("AssertionError", ast.Load())
        exc = ast.Call(err_name, [fmt], [])
        raise_ = ast.Raise(exc, None)

        body.append(raise_)

    # Clear temporary variables by setting them to None.
    if self.variables:
        variables = [ast.Name(name, ast.Store()) for name in self.variables]
        clear = ast.Assign(variables, ast.Constant(None))
        self.statements.append(clear)
    # Fix locations (line numbers/column offsets).
    for stmt in self.statements:
        for node in traverse_node(stmt):
            ast.copy_location(node, assert_)
    return self.statements

940 

def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
    """Return ``name`` unchanged together with an explanation placeholder.

    The explanation expression built here is equivalent to::

        display(name) if ("<id>" in locals() or <dorepr>) else "<id>"

    where ``<dorepr>`` is the node returned by
    ``self.helper("_should_repr_global_name", name)``.

    String literals are created with ``ast.Constant`` because ``ast.Str``
    is deprecated since Python 3.8 and removed in newer Python releases.
    """
    # Display the repr of the name if it's a local variable or
    # _should_repr_global_name() thinks it's acceptable.
    locs = ast.Call(self.builtin("locals"), [], [])
    inlocs = ast.Compare(ast.Constant(name.id), [ast.In()], [locs])
    dorepr = self.helper("_should_repr_global_name", name)
    test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
    # Fall back to the bare identifier text when neither condition holds.
    expr = ast.IfExp(test, self.display(name), ast.Constant(name.id))
    return name, self.explanation_param(expr)

950 

def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
    """Rewrite an ``and``/``or`` expression, keeping short-circuiting.

    Each operand is visited in order and its result assigned to one shared
    result variable. Every operand after the first is evaluated inside a
    nested ``if`` on the previous result (negated for ``or``), and the
    explanation statements are nested the same way.

    Literal nodes are created with ``ast.Constant``; ``ast.Str`` and
    ``ast.Num`` are deprecated since Python 3.8 and removed in newer
    Python releases.

    :returns: A load of the result variable and an explanation placeholder.
    """
    res_var = self.variable()
    expl_list = self.assign(ast.List([], ast.Load()))
    app = ast.Attribute(expl_list, "append", ast.Load())
    is_or = int(isinstance(boolop.op, ast.Or))
    # Remember the outer statement lists; they are swapped for nested
    # lists inside the loop and restored afterwards.
    body = save = self.statements
    fail_save = self.expl_stmts
    levels = len(boolop.values) - 1
    self.push_format_context()
    # Process each operand, short-circuiting if needed.
    for i, v in enumerate(boolop.values):
        if i:
            fail_inner: List[ast.stmt] = []
            # cond is set in a prior loop iteration below
            self.expl_stmts.append(ast.If(cond, fail_inner, []))  # noqa
            self.expl_stmts = fail_inner
        self.push_format_context()
        res, expl = self.visit(v)
        body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
        expl_format = self.pop_format_context(ast.Constant(expl))
        call = ast.Call(app, [expl_format], [])
        self.expl_stmts.append(ast.Expr(call))
        if i < levels:
            # Evaluate the next operand only if this one did not decide
            # the result: truthy for `and`, falsy for `or`.
            cond: ast.expr = res
            if is_or:
                cond = ast.UnaryOp(ast.Not(), cond)
            inner: List[ast.stmt] = []
            self.statements.append(ast.If(cond, inner, []))
            self.statements = body = inner
    self.statements = save
    self.expl_stmts = fail_save
    expl_template = self.helper("_format_boolop", expl_list, ast.Constant(is_or))
    expl = self.pop_format_context(expl_template)
    return ast.Name(res_var, ast.Load()), self.explanation_param(expl)

985 

def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]:
    """Rewrite a unary operation around its already-rewritten operand.

    The explanation is the ``UNARY_MAP`` pattern for the operator class,
    %-formatted with the operand's explanation.
    """
    # Look the pattern up first so an unknown operator fails before the
    # operand is visited.
    template = UNARY_MAP[unary.op.__class__]
    inner_node, inner_expl = self.visit(unary.operand)
    rebuilt = ast.UnaryOp(unary.op, inner_node)
    return self.assign(rebuilt), template % (inner_expl,)

991 

def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]:
    """Rewrite a binary operation from its rewritten left/right operands.

    The explanation has the form ``(<left> <symbol> <right>)``, with the
    symbol taken from ``BINOP_MAP`` for the operator class.
    """
    op_text = BINOP_MAP[binop.op.__class__]
    # Operands are visited left to right, matching evaluation order.
    lhs, lhs_expl = self.visit(binop.left)
    rhs, rhs_expl = self.visit(binop.right)
    text = "({} {} {})".format(lhs_expl, op_text, rhs_expl)
    rebuilt = ast.BinOp(lhs, binop.op, rhs)
    return self.assign(rebuilt), text

999 

def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
    """Rewrite a call expression and describe it in the explanation.

    The callee, positional arguments and keyword values are visited in
    that order. The call is rebuilt from the visited nodes and passed to
    ``self.assign``. The explanation reads
    ``<result>\\n{<result> = <func>(<args>)\\n}``.
    """
    callee, callee_expl = self.visit(call.func)

    pieces = []
    positional = []
    for argument in call.args:
        node, text = self.visit(argument)
        pieces.append(text)
        positional.append(node)

    named = []
    for kw in call.keywords:
        node, text = self.visit(kw.value)
        named.append(ast.keyword(kw.arg, node))
        # **args have `arg` keywords with an .arg of None
        prefix = kw.arg + "=" if kw.arg else "**"
        pieces.append(prefix + text)

    rendered = "{}({})".format(callee_expl, ", ".join(pieces))
    result = self.assign(ast.Call(callee, positional, named))
    shown = self.explanation_param(self.display(result))
    return result, f"{shown}\n{{{shown} = {rendered}\n}}"

1023 

def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]:
    """Rewrite a starred expression (e.g. ``*args`` inside a call).

    The wrapped value is visited and re-wrapped in a new ``ast.Starred``
    with the original context; the explanation is prefixed with ``*``.
    """
    # A Starred node can appear in a function call.
    inner_node, inner_expl = self.visit(starred.value)
    return ast.Starred(inner_node, starred.ctx), "*" + inner_expl

1029 

def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
    """Rewrite an attribute read and describe it in the explanation.

    Only ``Load`` contexts are rewritten; any other context is handed to
    ``self.generic_visit``. The explanation reads
    ``<result>\\n{<result> = <value>.<attr>\\n}``.
    """
    if isinstance(attr.ctx, ast.Load):
        owner, owner_expl = self.visit(attr.value)
        lookup = ast.Attribute(owner, attr.attr, ast.Load())
        result = self.assign(lookup)
        shown = self.explanation_param(self.display(result))
        text = "%s\n{%s = %s.%s\n}" % (shown, shown, owner_expl, attr.attr)
        return result, text
    return self.generic_visit(attr)

1039 

def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
    """Rewrite a (possibly chained) comparison into pairwise comparisons.

    Each ``left <op> right`` pair of the chain is assigned to its own
    result variable. The overall result is that single variable, or an
    ``and`` of all of them when the chain has more than one operator.

    String literals are created with ``ast.Constant`` because ``ast.Str``
    is deprecated since Python 3.8 and removed in newer Python releases.

    :returns: The result expression and an explanation placeholder.
    """
    self.push_format_context()
    left_res, left_expl = self.visit(comp.left)
    # Parenthesize nested comparisons/boolean ops so the text stays readable.
    if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
        left_expl = f"({left_expl})"
    # One result variable per operator, allocated before visiting operands.
    res_variables = [self.variable() for _ in comp.ops]
    load_names = [ast.Name(v, ast.Load()) for v in res_variables]
    store_names = [ast.Name(v, ast.Store()) for v in res_variables]
    expls = []
    syms = []
    results = [left_res]
    for i, (op, next_operand) in enumerate(zip(comp.ops, comp.comparators)):
        next_res, next_expl = self.visit(next_operand)
        if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
            next_expl = f"({next_expl})"
        results.append(next_res)
        sym = BINOP_MAP[op.__class__]
        syms.append(ast.Constant(sym))
        expl = f"{left_expl} {sym} {next_expl}"
        expls.append(ast.Constant(expl))
        res_expr = ast.Compare(left_res, [op], [next_res])
        self.statements.append(ast.Assign([store_names[i]], res_expr))
        # The right operand becomes the left operand of the next pair.
        left_res, left_expl = next_res, next_expl
    # Use pytest.assertion.util._reprcompare if that's available.
    expl_call = self.helper(
        "_call_reprcompare",
        ast.Tuple(syms, ast.Load()),
        ast.Tuple(load_names, ast.Load()),
        ast.Tuple(expls, ast.Load()),
        ast.Tuple(results, ast.Load()),
    )
    if len(comp.ops) > 1:
        res: ast.expr = ast.BoolOp(ast.And(), load_names)
    else:
        res = load_names[0]
    return res, self.explanation_param(self.pop_format_context(expl_call))

1077 

1078 

def try_makedirs(cache_dir: Path) -> bool:
    """Attempt to create the given directory, including missing parents.

    Returns True if the directory was created or already exists, and False
    when creation fails for one of the tolerated reasons handled below.
    Any other OSError is re-raised.
    """
    tolerated = (
        # One of the path components was not a directory:
        # - we're in a zip file
        # - it is a file
        FileNotFoundError,
        NotADirectoryError,
        FileExistsError,
        PermissionError,
    )
    try:
        os.makedirs(cache_dir, exist_ok=True)
    except OSError as exc:
        # as of now, EROFS doesn't have an equivalent OSError-subclass
        if isinstance(exc, tolerated) or exc.errno == errno.EROFS:
            return False
        raise
    else:
        return True

1099 

1100 

def get_cache_dir(file_path: Path) -> Path:
    """Return the directory where the .pyc for the given .py file is cached."""
    if not (sys.version_info >= (3, 8) and sys.pycache_prefix):
        # classic pycache directory
        return file_path.parent / "__pycache__"

    # Mirror the file's directory under the prefix. Given:
    #   prefix = '/tmp/pycs'
    #   path = '/home/user/proj/test_app.py'
    # we want:
    #   '/tmp/pycs/home/user/proj'
    # parts[1:-1] drops the leading root component and the file name.
    mirrored = Path(*file_path.parts[1:-1])
    return Path(sys.pycache_prefix) / mirrored