Coverage for src/fastoai/models/_codegen.py: 99%

157 statements  

« prev     ^ index     » next       coverage.py v7.6.8, created at 2024-12-06 09:34 +0800

1import ast 

2import inspect 

3from collections import defaultdict 

4from pathlib import Path 

5from types import ModuleType 

6from typing import cast 

7 

8from pydantic.alias_generators import to_snake 

9 

# Prefixes passed to ``random_id_with_prefix`` for generated primary keys.
# Classes not listed here fall back to ``<class name lower-cased>_``.
ID_PREFIXES = {
    "FileObject": "file-",
    "Assistant": "asst_",
    "RunStep": "step_",
    "Message": "msg_",
}

# Class names whose table name is not simply ``ClassName.lower()``.
CLS_TO_TABLE = {
    "FileObject": "file",
    "RunStep": "step",
}
# Reverse lookup (table -> class), derived so the two maps cannot drift apart.
TABLE_TO_CLS = {table: cls for cls, table in CLS_TO_TABLE.items()}

25 

26 

def ruff_check(path: Path):
    """Run ``ruff check --fix`` on *path* with the ruff binary from the ruff package.

    The subprocess exit status is not checked, so lint errors that ruff
    cannot fix do not raise here.
    """
    import os
    import subprocess

    from ruff.__main__ import find_ruff_bin

    command = [os.fsdecode(find_ruff_bin()), "check", "--fix", path]
    subprocess.run(command)

37 

38 

39def Annotated(type_, *args) -> ast.Subscript: 

40 return ast.Subscript( 

41 value=ast.Name(id="Annotated", ctx=ast.Load()), 

42 slice=ast.Tuple( 

43 elts=[type_, *args], 

44 ctx=ast.Load(), 

45 ), 

46 ctx=ast.Load(), 

47 ) 

48 

49 

def Field(*args, **kwargs) -> ast.Call:
    """Build the AST for a ``Field(...)`` call.

    Positional arguments become call arguments in order; each keyword
    argument becomes ``name=value`` in insertion order. All values must
    already be AST expression nodes.
    """
    keywords = []
    for name, value in kwargs.items():
        keywords.append(ast.keyword(arg=name, value=value))
    return ast.Call(
        func=ast.Name(id="Field", ctx=ast.Load()),
        args=[*args],
        keywords=keywords,
    )

56 

57 

def _fix_id(class_def: ast.ClassDef):
    """Turn the class's first statement into a generated primary key.

    The first statement of ``class_def.body`` is assumed to be the ``id``
    annotated assignment (``generate_module`` keeps only ``AnnAssign``
    statements before calling this). Its annotation becomes::

        Annotated[<original>, Field(primary_key=True,
                                    default_factory=random_id_with_prefix(<prefix>))]

    where ``<prefix>`` comes from ``ID_PREFIXES`` or defaults to
    ``<class name lower-cased>_``.
    """
    id_field = cast(ast.AnnAssign, class_def.body[0])
    fallback_prefix = class_def.name.lower() + "_"
    prefix = ID_PREFIXES.get(class_def.name, fallback_prefix)
    id_factory = ast.Call(
        func=ast.Name(id="random_id_with_prefix", ctx=ast.Load()),
        args=[ast.Constant(value=prefix)],
        keywords=[],
    )
    key_field = Field(
        primary_key=ast.Constant(value=True),
        default_factory=id_factory,
    )
    id_field.annotation = Annotated(id_field.annotation, key_field)

77 

78 

def _fix_literal(class_def: ast.ClassDef):
    """Map ``Literal[...]`` fields onto SQL enum columns.

    Every annotated field whose annotation is exactly ``Literal[...]`` is
    rewritten in place to::

        Annotated[Literal[...], Field(sa_type=Enum(<literal values>))]

    Both multi-value (``Literal["a", "b"]``) and single-value
    (``Literal["a"]``) literals are supported.
    """
    for stmt in class_def.body:
        if not (
            isinstance(stmt, ast.AnnAssign)
            and isinstance(literal := stmt.annotation, ast.Subscript)
            and isinstance(literal.value, ast.Name)
            and literal.value.id == "Literal"
        ):
            continue
        # ``Literal["a", "b"]`` parses with a Tuple slice, but ``Literal["a"]``
        # has the bare constant as its slice (no ``.elts``); normalise both.
        if isinstance(literal.slice, ast.Tuple):
            values = list(literal.slice.elts)
        else:
            values = [literal.slice]
        enum_type = ast.Call(
            func=ast.Name(id="Enum", ctx=ast.Load()),
            args=values,
            keywords=[],
        )
        stmt.annotation = Annotated(literal, Field(sa_type=enum_type))

99 

100 

101def _get_object_literal(n: ast.ClassDef) -> str: 

102 object_field = next( 

103 ( 

104 a 

105 for a in n.body 

106 if isinstance(a, ast.AnnAssign) 

107 and isinstance(a.target, ast.Name) 

108 and a.target.id == "object" 

109 ) 

110 ) 

111 n.body.remove(object_field) 

112 return cast( 

113 str, 

114 cast(ast.Constant, cast(ast.Subscript, object_field.annotation).slice).value, 

115 ) 

116 

117 

118def _fix_optional(class_def: ast.ClassDef): 

119 for stmt in class_def.body: 

120 if ( 

121 isinstance(stmt, ast.AnnAssign) 

122 and isinstance(ann := stmt.annotation, ast.Subscript) 

123 and isinstance(ann.value, ast.Name) 

124 and ann.value.id == "Optional" 

125 ): 

126 if isinstance(ann.slice, ast.Name) and ann.slice.id in { 

127 "str", 

128 "int", 

129 "float", 

130 "bool", 

131 }: 

132 stmt.annotation = ast.BinOp( 

133 left=ann.slice, 

134 op=ast.BitOr(), 

135 right=ast.Name(id="None", ctx=ast.Load()), 

136 ) 

137 else: 

138 if ( 

139 isinstance(ann.slice, ast.Subscript) 

140 and isinstance(ann.slice.value, ast.Name) 

141 and ann.slice.value.id == "List" 

142 ): 

143 ann.slice.value.id = "list" 

144 t = ast.BinOp( 

145 left=ann.slice, 

146 op=ast.BitOr(), 

147 right=ast.Name(id="None", ctx=ast.Load()), 

148 ) 

149 as_sa_type_call = ast.Call( 

150 func=ast.Name(id="as_sa_type", ctx=ast.Load()), 

151 args=[ann.slice], 

152 keywords=[], 

153 ) 

154 f = Field( 

155 sa_type=as_sa_type_call, 

156 nullable=ast.Constant(value=True), 

157 ) 

158 stmt.annotation = Annotated(t, f) 

159 

160 

161def _fix_list(class_def: ast.ClassDef): 

162 for stmt in class_def.body: 

163 if ( 

164 isinstance(ann_assign := stmt, ast.AnnAssign) 

165 and isinstance(ann := ann_assign.annotation, ast.Subscript) 

166 and isinstance(ann.value, ast.Name) 

167 ): 

168 if ann.value.id == "List": 

169 ann.value.id = "list" 

170 as_sa_type_call = ast.Call( 

171 func=ast.Name(id="as_sa_type", ctx=ast.Load()), 

172 args=[ann], 

173 keywords=[], 

174 ) 

175 ann_assign.annotation = Annotated( 

176 ann, 

177 Field( 

178 default_factory=ast.Name("list", ctx=ast.Load()), 

179 sa_type=as_sa_type_call, 

180 ), 

181 ) 

182 

183 

184def _fix_name(class_def: ast.ClassDef): 

185 for stmt in class_def.body: 

186 if ( 

187 isinstance(ann_assign := stmt, ast.AnnAssign) 

188 and isinstance(ann := ann_assign.annotation, ast.Name) 

189 and ann.id not in {"str", "int", "float", "bool"} 

190 ): 

191 as_sa_type_call = ast.Call( 

192 func=ast.Name(id="as_sa_type", ctx=ast.Load()), 

193 args=[ann], 

194 keywords=[], 

195 ) 

196 ann_assign.annotation = Annotated( 

197 ann, 

198 Field( 

199 sa_type=as_sa_type_call, 

200 ), 

201 ) 

202 

203 

def Optional(type_) -> ast.BinOp:
    """Build the AST for the PEP 604 union ``type_ | None``."""
    none = ast.Name(id="None", ctx=ast.Load())
    return ast.BinOp(type_, ast.BitOr(), none)

210 

211 

212def _fix_timestamp(class_def: ast.ClassDef): 

213 timestamps = [] 

214 optional = False 

215 for stmt in class_def.body: 

216 if ( 

217 isinstance(ann_assign := stmt, ast.AnnAssign) 

218 and isinstance(ann_assign.target, ast.Name) 

219 and ann_assign.target.id.endswith("_at") 

220 ): 

221 if isinstance(ann_assign.annotation, ast.Name): 

222 ann_assign.annotation = Annotated( 

223 ast.Name(id="datetime", ctx=ast.Load()), 

224 Field(default_factory=ast.Name(id="now", ctx=ast.Load())), 

225 ) 

226 else: 

227 ann_assign.annotation = Optional( 

228 ast.Name(id="datetime", ctx=ast.Load()) 

229 ) 

230 optional = True 

231 timestamps.append(ann_assign.target.id) 

232 if timestamps: 

233 name_int = ast.Name(id="int", ctx=ast.Load()) 

234 name_datetime = ast.Name(id="datetime", ctx=ast.Load()) 

235 serializor = ast.FunctionDef( 

236 name="serialize_datetime", 

237 args=ast.arguments( 

238 posonlyargs=[], 

239 args=[ 

240 ast.arg("self"), 

241 ast.arg( 

242 arg="dt", 

243 annotation=Optional(name_datetime) 

244 if optional 

245 else name_datetime, 

246 ), 

247 ], 

248 kwonlyargs=[], 

249 kw_defaults=[], 

250 defaults=[], 

251 ), 

252 body=[ 

253 *( 

254 [ 

255 ast.If( 

256 test=ast.Compare( 

257 left=ast.Name(id="dt", ctx=ast.Load()), 

258 ops=[ast.Is()], 

259 comparators=[ast.Name(id="None", ctx=ast.Load())], 

260 ), 

261 body=[ast.Return(value=ast.Constant(value=None))], 

262 orelse=[], 

263 ) 

264 ] 

265 if optional 

266 else [] 

267 ), 

268 ast.Return( 

269 value=ast.Call( 

270 func=name_int, 

271 args=[ 

272 ast.Call( 

273 func=ast.Attribute( 

274 value=ast.Name(id="dt", ctx=ast.Load()), 

275 attr="timestamp", 

276 ctx=ast.Load(), 

277 ), 

278 args=[], 

279 keywords=[], 

280 ) 

281 ], 

282 keywords=[], 

283 ) 

284 ), 

285 ], 

286 decorator_list=[ 

287 ast.Call( 

288 func=ast.Name(id="field_serializer", ctx=ast.Load()), 

289 args=[ast.Constant(value=f) for f in timestamps], 

290 keywords=[], 

291 ) 

292 ], 

293 returns=Optional(name_int) if optional else name_int, 

294 type_params=[], 

295 ) 

296 class_def.body.append(serializor) 

297 

298 

299def _to_openai_model( 

300 openai_model_name: str, object_literal: str 

301) -> ast.AsyncFunctionDef: 

302 return ast.AsyncFunctionDef( 

303 name="to_openai_model", 

304 args=ast.arguments( 

305 posonlyargs=[], 

306 args=[ast.arg(arg="self")], 

307 kwonlyargs=[], 

308 kw_defaults=[], 

309 defaults=[], 

310 ), 

311 body=[ 

312 ast.Assign( 

313 targets=[ast.Name(id="value", ctx=ast.Store())], 

314 value=ast.Call( 

315 func=ast.Attribute( 

316 value=ast.Name(id="self", ctx=ast.Load()), 

317 attr="model_dump", 

318 ctx=ast.Load(), 

319 ), 

320 args=[], 

321 keywords=[ast.keyword("by_alias", ast.Constant(True))], 

322 ), 

323 ), 

324 ast.Assign( 

325 targets=[ 

326 ast.Subscript( 

327 value=ast.Name(id="value", ctx=ast.Load()), 

328 slice=ast.Constant(value="object"), 

329 ctx=ast.Store(), 

330 ) 

331 ], 

332 value=ast.Constant(value=object_literal), 

333 ), 

334 ast.Return( 

335 value=ast.Call( 

336 func=ast.Attribute( 

337 value=ast.Name(id=openai_model_name, ctx=ast.Load()), 

338 attr="model_validate", 

339 ctx=ast.Load(), 

340 ), 

341 args=[ast.Name(id="value", ctx=ast.Load())], 

342 keywords=[], 

343 ) 

344 ), 

345 ], 

346 decorator_list=[], 

347 returns=ast.Name(id=openai_model_name, ctx=ast.Load()), 

348 type_params=[], 

349 ) 

350 

351 

def _absolutize_imports(imports: list[ast.stmt], module_name: str) -> None:
    """Rewrite the source module's ``from ... import`` statements in place.

    ``typing_extensions`` imports are pointed at ``typing``; relative imports
    are resolved against *module_name* (treated as a module, not a package)
    and turned into absolute ones. Statements that are not ``ImportFrom``
    are left untouched instead of crashing on a missing ``.module``.
    """
    for stmt in imports:
        if not isinstance(stmt, ast.ImportFrom):
            continue
        if stmt.module == "typing_extensions":
            stmt.module = "typing"
        elif stmt.level != 0:
            parts = module_name.split(".")[: -stmt.level]
            # ``from . import x`` has no module name; appending "" would leave
            # a trailing dot ("pkg.") in the resolved name.
            if stmt.module:
                parts.append(stmt.module)
            stmt.module = ".".join(parts)
            stmt.level = 0


def _reexports(statements: list[ast.stmt], module_name: str) -> list[ast.ImportFrom]:
    """Build imports for definitions that stay in the source module.

    Classes are imported by name; annotated assignments are imported by the
    base name of a subscripted annotation, or by their target name when
    annotated ``TypeAlias``. Other statements are ignored.
    """
    result: list[ast.ImportFrom] = []
    for node in statements:
        name = None
        if isinstance(node, ast.ClassDef):
            name = node.name
        elif isinstance(node, ast.AnnAssign):
            if isinstance(node.annotation, ast.Subscript):
                name = cast(ast.Name, node.annotation.value).id
            elif (
                isinstance(node.annotation, ast.Name)
                and node.annotation.id == "TypeAlias"
            ):
                name = cast(ast.Name, node.target).id
        if name is not None:
            result.append(ast.ImportFrom(module_name, [ast.alias(name)], 0))
    return result


def generate_module(module: ModuleType) -> ast.Module:
    """Generate the SQLModel table module for the model defined in *module*.

    The source of *module* is parsed. Statements before the top-level
    ``__all__`` assignment are treated as its imports, and the last
    top-level statement as the model class. The returned module contains:

    * those imports, rewritten by ``_absolutize_imports``;
    * the SQLModel / pydantic / project helper imports the rewrite needs;
    * an import of the original class as ``_<Name>`` plus imports of the
      other definitions left in *module*;
    * the rewritten table class as its final statement.

    Raises:
        ValueError: if *module* has no top-level ``__all__`` assignment.
    """
    tree = ast.parse(inspect.getsource(module))
    all_index = next(
        (
            index
            for index, node in enumerate(tree.body)
            if isinstance(node, ast.Assign)
            and len(node.targets) == 1
            and isinstance(node.targets[0], ast.Name)
            and node.targets[0].id == "__all__"
        ),
        None,
    )
    if all_index is None:
        raise ValueError(f"module {module.__name__!r} does not define __all__")
    imports: list[ast.stmt] = tree.body[:all_index]
    body = tree.body[all_index + 1 :]
    _absolutize_imports(imports, module.__name__)
    imports.extend(
        [
            ast.ImportFrom("datetime", [ast.alias("datetime")], 0),
            ast.ImportFrom("typing", [ast.alias("Annotated")], 0),
            ast.ImportFrom("sqlalchemy.ext.asyncio", [ast.alias("AsyncAttrs")], 0),
            ast.ImportFrom(
                "sqlmodel",
                [
                    ast.alias("SQLModel"),
                    ast.alias("Enum"),
                    ast.alias("Field"),
                    ast.alias("Relationship"),
                ],
                0,
            ),
            ast.ImportFrom("pydantic", [ast.alias("field_serializer")], 0),
            ast.ImportFrom("_metadata", [ast.alias("WithMetadata")], 2),
            ast.ImportFrom("_types", [ast.alias("as_sa_type")], 2),
            ast.ImportFrom(
                "_utils", [ast.alias("now"), ast.alias("random_id_with_prefix")], 2
            ),
        ]
    )
    # The model class is the last top-level statement of the source module.
    class_def = cast(ast.ClassDef, body.pop())

    # Only annotated fields are carried over into the table class.
    class_def.body = [
        stmt for stmt in class_def.body if isinstance(stmt, ast.AnnAssign)
    ]
    metadata_field = next(
        (
            a
            for a in class_def.body
            if isinstance(a, ast.AnnAssign)
            and isinstance(a.target, ast.Name)
            and a.target.id == "metadata"
        ),
        None,
    )
    if metadata_field:
        # A ``metadata`` field switches the base to ``WithMetadata`` below
        # (presumably that base provides the column).
        class_def.body.remove(metadata_field)
    class_def.bases = [
        ast.Name(id="AsyncAttrs", ctx=ast.Load()),
        ast.Name(
            id="WithMetadata" if metadata_field else "SQLModel",
            ctx=ast.Load(),
        ),
    ]
    class_def.keywords = [ast.keyword(arg="table", value=ast.Constant(value=True))]
    _fix_id(class_def)
    if tablename := CLS_TO_TABLE.get(class_def.name):
        class_def.body.insert(
            0,
            ast.Assign(
                targets=[ast.Name(id="__tablename__", ctx=ast.Store())],
                value=ast.Constant(value=tablename),
            ),
        )
    # Import the original model under a private alias so the generated
    # ``to_openai_model`` can validate into it without a name clash.
    imports.append(
        ast.ImportFrom(
            module.__name__,
            [ast.alias(class_def.name, f"_{class_def.name}")],
            0,
        )
    )
    object_literal = _get_object_literal(class_def)
    class_def.body.append(_to_openai_model(f"_{class_def.name}", object_literal))
    _fix_literal(class_def)
    _fix_optional(class_def)
    _fix_list(class_def)
    _fix_name(class_def)
    _fix_timestamp(class_def)
    imports.extend(_reexports(body, module.__name__))
    tree.body = imports + tree.body[-1:]
    return tree

476 

477 

def Relationship(*, back_populates: str) -> ast.Call:
    """Build the AST for ``Relationship(back_populates="<back_populates>")``."""
    keyword = ast.keyword(
        arg="back_populates", value=ast.Constant(value=back_populates)
    )
    return ast.Call(
        func=ast.Name(id="Relationship", ctx=ast.Load()),
        args=[],
        keywords=[keyword],
    )

486 

487 

def _add_foreign_key(module: ast.Module, back_populates: dict[str, list[str]]):
    """Turn ``<table>_id`` fields of the module's class into foreign keys.

    For every annotated field named ``<table>_id`` on the class (the last
    statement of *module*):

    * its annotation is wrapped as
      ``Annotated[..., Field(foreign_key="<table>.id")]``, with
      ``nullable=True`` added when the annotation is a binary-operator
      union such as ``X | None``;
    * ``from .<snake class> import <Class>`` is inserted at the top of
      *module*;
    * a ``<table>: <Class>`` (or ``<Class> | None``) field is appended with
      ``Relationship(back_populates="<this table>s")``;
    * this class's table is appended to ``back_populates[<table>]`` so the
      target class can later receive the reverse relationship.
    """
    class_def = cast(ast.ClassDef, module.body[-1])
    this_table = _c2t(class_def.name)
    # Iterate over a snapshot: relationship fields are appended to the body
    # during the scan and must not be visited again.
    for stmt in list(class_def.body):
        if not (
            isinstance(stmt, ast.AnnAssign)
            and isinstance(field_name := stmt.target, ast.Name)
            and field_name.id.endswith("_id")
        ):
            continue
        table = field_name.id[:-3]
        fk = Field(foreign_key=ast.Constant(value=f"{table}.id"))
        optional = isinstance(stmt.annotation, ast.BinOp)
        if optional:
            fk.keywords.append(
                ast.keyword(arg="nullable", value=ast.Constant(value=True))
            )
        stmt.annotation = Annotated(stmt.annotation, fk)
        target_cls = _t2c(table)
        module.body.insert(
            0, ast.ImportFrom(to_snake(target_cls), [ast.alias(target_cls)], 1)
        )
        target_ann = ast.Name(id=target_cls, ctx=ast.Load())
        class_def.body.append(
            ast.AnnAssign(
                target=ast.Name(id=table, ctx=ast.Store()),
                annotation=Optional(target_ann) if optional else target_ann,
                value=Relationship(back_populates=this_table + "s"),
                simple=1,
            )
        )
        back_populates[table].append(this_table)

521 

522 

def _t2c(table_name: str):
    """Table name to class name.

    Irregular names are looked up in ``TABLE_TO_CLS``; anything else is
    passed through ``str.capitalize`` (``"thread"`` -> ``"Thread"``).
    """
    try:
        return TABLE_TO_CLS[table_name]
    except KeyError:
        return table_name.capitalize()

526 

527 

def _c2t(class_name: str):
    """Class name to table name.

    Irregular names are looked up in ``CLS_TO_TABLE``; anything else is
    lower-cased (``"Thread"`` -> ``"thread"``).
    """
    if class_name in CLS_TO_TABLE:
        return CLS_TO_TABLE[class_name]
    return class_name.lower()

530 

531 

def _add_back_populates(module: ast.Module, table: str, back_populates: list[str]):
    """Add one-to-many relationship fields to the class at the end of *module*.

    For every table name ``p`` in *back_populates* the class gets
    ``<p>s: list["<Class of p>"] = Relationship(back_populates="<table>")``.
    The referenced classes are imported under ``if TYPE_CHECKING:`` (placed
    after the existing statements) and the class stays the final statement.
    """
    *preceding, last = module.body
    class_def = cast(ast.ClassDef, last)
    relationship_fields = []
    for p in back_populates:
        relationship_fields.append(
            ast.AnnAssign(
                target=ast.Name(id=p + "s", ctx=ast.Store()),
                annotation=ast.Subscript(
                    value=ast.Name(id="list", ctx=ast.Load()),
                    # String forward reference: the class is only imported
                    # under TYPE_CHECKING.
                    slice=ast.Constant(value=_t2c(p)),
                    ctx=ast.Load(),
                ),
                value=Relationship(back_populates=table),
                simple=1,
            )
        )
    class_def.body.extend(relationship_fields)
    type_only_imports = []
    for p in back_populates:
        cls = _t2c(p)
        type_only_imports.append(
            ast.ImportFrom(to_snake(cls), names=[ast.alias(cls)], level=1)
        )
    module.body = [
        *preceding,
        ast.ImportFrom("typing", names=[ast.alias("TYPE_CHECKING")], level=0),
        ast.If(
            test=ast.Name("TYPE_CHECKING", ctx=ast.Load()),
            body=type_only_imports,
            orelse=[],
        ),
        class_def,
    ]

567 

568 

def generate_modules(*modules: ModuleType):
    """Generate SQLModel table modules for *modules* and write them to disk.

    Each module is converted with ``generate_module``. Foreign keys and the
    matching reverse relationships are then wired between the generated
    classes. Every result is written to ``generated/<module basename>.py``
    next to this file; the ``generated`` directory is created if missing.
    """
    generated: dict[str, ast.Module] = {}  # table name -> generated module
    filenames: dict[str, str] = {}  # table name -> output file stem
    for module in modules:
        tree = generate_module(module)
        # Key by the table derived from the *class* name (CLS_TO_TABLE is keyed
        # by class names). ``_add_foreign_key`` derives tables like "step" or
        # "file" from ``*_id`` fields, which the module basenames "run_step" /
        # "file_object" would never match.
        table = _c2t(cast(ast.ClassDef, tree.body[-1]).name)
        generated[table] = tree
        filenames[table] = module.__name__.split(".")[-1]
    back_populates: dict[str, list[str]] = defaultdict(list)
    for tree in generated.values():
        _add_foreign_key(tree, back_populates)
    for table, back in back_populates.items():
        _add_back_populates(generated[table], table, back)
    out_dir = Path(__file__).parent / "generated"
    out_dir.mkdir(parents=True, exist_ok=True)
    for table, tree in generated.items():
        dest = (out_dir / filenames[table]).with_suffix(".py")
        dest.write_text(ast.unparse(ast.fix_missing_locations(tree)))