Coverage for src/fastoai/models/_codegen.py: 99%
157 statements
« prev ^ index » next coverage.py v7.6.8, created at 2024-12-06 09:34 +0800
1import ast
2import inspect
3from collections import defaultdict
4from pathlib import Path
5from types import ModuleType
6from typing import cast
8from pydantic.alias_generators import to_snake
# ID prefixes for models whose generated IDs should not use the default
# "<lowercased class name>_" prefix applied by `_fix_id`.
ID_PREFIXES = {
    "FileObject": "file-",
    "Assistant": "asst_",
    "RunStep": "step_",
    "Message": "msg_",
}

# Table names for classes whose table name is not simply the lowercased
# class name.
CLS_TO_TABLE = {
    "FileObject": "file",
    "RunStep": "step",
}
# Reverse lookup of CLS_TO_TABLE, derived from it so the two cannot drift apart.
TABLE_TO_CLS = {table: cls for cls, table in CLS_TO_TABLE.items()}
def ruff_check(path: Path):
    """Run ``ruff check --fix`` on *path*.

    The ruff executable is located with ``ruff.__main__.find_ruff_bin``. The
    process's exit status is not checked.
    """
    import os
    import subprocess

    from ruff.__main__ import find_ruff_bin

    executable = os.fsdecode(find_ruff_bin())
    command = [executable, "check", "--fix", path]
    subprocess.run(command)
def Annotated(type_, *args) -> ast.Subscript:
    """Build the AST for ``Annotated[type_, *args]``.

    Args:
        type_: expression node for the annotated type.
        *args: metadata expression nodes placed after the type.

    Returns:
        An ``ast.Subscript`` whose slice is a tuple of ``type_`` and ``args``.
    """
    elements = [type_]
    elements.extend(args)
    subscript_slice = ast.Tuple(elts=elements, ctx=ast.Load())
    annotated_name = ast.Name(id="Annotated", ctx=ast.Load())
    return ast.Subscript(value=annotated_name, slice=subscript_slice, ctx=ast.Load())
def Field(*args, **kwargs) -> ast.Call:
    """Build the AST for a ``Field(*args, **kwargs)`` call.

    Positional ``args`` and the values of ``kwargs`` must already be AST
    expression nodes. Keywords are emitted in ``kwargs`` insertion order.
    """
    keywords = []
    for name, value in kwargs.items():
        keywords.append(ast.keyword(arg=name, value=value))
    callee = ast.Name(id="Field", ctx=ast.Load())
    return ast.Call(func=callee, args=[*args], keywords=keywords)
def _fix_id(class_def: ast.ClassDef):
    """Make the class's first field a primary key with a generated default.

    The first statement of ``class_def.body`` is taken to be the ``id``
    field; it is not looked up by name. Its annotation becomes
    ``Annotated[<type>, Field(primary_key=True,
    default_factory=random_id_with_prefix("<prefix>"))]``. The prefix comes from
    ``ID_PREFIXES`` and falls back to ``"<lowercased class name>_"``.
    """
    id_field = cast(ast.AnnAssign, class_def.body[0])
    prefix = ID_PREFIXES.get(class_def.name, class_def.name.lower() + "_")
    factory = ast.Call(
        func=ast.Name(id="random_id_with_prefix", ctx=ast.Load()),
        args=[ast.Constant(value=prefix)],
        keywords=[],
    )
    field_call = Field(primary_key=ast.Constant(value=True), default_factory=factory)
    id_field.annotation = Annotated(id_field.annotation, field_call)
def _fix_literal(class_def: ast.ClassDef):
    """Store ``Literal[...]`` fields as SQL enum columns.

    Every annotated field whose annotation is ``Literal[...]`` is rewritten in
    place to ``Annotated[Literal[...], Field(sa_type=Enum(<literal values>))]``.
    Both multi-value literals (``Literal["a", "b"]``, tuple subscript) and
    single-value literals (``Literal["a"]``, bare subscript) are supported.
    """
    for n in class_def.body:
        if not (
            isinstance(n, ast.AnnAssign)
            and isinstance(n.annotation, ast.Subscript)
            and isinstance(n.annotation.value, ast.Name)
            and n.annotation.value.id == "Literal"
        ):
            continue
        t = n.annotation
        # `Literal["a", "b"]` subscripts with a Tuple; `Literal["a"]` with the
        # bare value node, which has no `.elts`.
        values = list(t.slice.elts) if isinstance(t.slice, ast.Tuple) else [t.slice]
        f = Field(
            sa_type=ast.Call(
                func=ast.Name(id="Enum", ctx=ast.Load()),
                args=values,
                keywords=[],
            )
        )
        n.annotation = Annotated(t, f)
def _get_object_literal(n: ast.ClassDef) -> str:
    """Remove the ``object: Literal["..."]`` field from *n* and return its value.

    Args:
        n: class definition; modified in place (the field is removed).

    Returns:
        The string inside the ``Literal[...]`` annotation of the ``object``
        field.

    Raises:
        ValueError: if *n* has no annotated ``object`` field.
    """
    object_field = next(
        (
            a
            for a in n.body
            if isinstance(a, ast.AnnAssign)
            and isinstance(a.target, ast.Name)
            and a.target.id == "object"
        ),
        None,
    )
    if object_field is None:
        # Report the problem clearly instead of leaking a bare StopIteration.
        raise ValueError(f"class {n.name!r} has no annotated 'object' field")
    n.body.remove(object_field)
    return cast(
        str,
        cast(ast.Constant, cast(ast.Subscript, object_field.annotation).slice).value,
    )
def _fix_optional(class_def: ast.ClassDef):
    """Rewrite ``Optional[X]`` field annotations in place.

    * ``Optional[str]``, ``Optional[int]``, ``Optional[float]`` and
      ``Optional[bool]`` become the plain union ``X | None``.
    * Any other ``Optional[X]`` becomes
      ``Annotated[X | None, Field(sa_type=as_sa_type(X), nullable=True)]``.
      If ``X`` is ``List[...]``, it is first renamed to ``list[...]``.
    """

    def or_none(node: ast.expr) -> ast.BinOp:
        # `node | None`
        return ast.BinOp(
            left=node, op=ast.BitOr(), right=ast.Name(id="None", ctx=ast.Load())
        )

    scalar_names = {"str", "int", "float", "bool"}
    for stmt in class_def.body:
        if not isinstance(stmt, ast.AnnAssign):
            continue
        ann = stmt.annotation
        is_optional = (
            isinstance(ann, ast.Subscript)
            and isinstance(ann.value, ast.Name)
            and ann.value.id == "Optional"
        )
        if not is_optional:
            continue
        inner = ann.slice
        if isinstance(inner, ast.Name) and inner.id in scalar_names:
            stmt.annotation = or_none(inner)
            continue
        if (
            isinstance(inner, ast.Subscript)
            and isinstance(inner.value, ast.Name)
            and inner.value.id == "List"
        ):
            inner.value.id = "list"
        column_type = ast.Call(
            func=ast.Name(id="as_sa_type", ctx=ast.Load()),
            args=[inner],
            keywords=[],
        )
        stmt.annotation = Annotated(
            or_none(inner),
            Field(sa_type=column_type, nullable=ast.Constant(value=True)),
        )
def _fix_list(class_def: ast.ClassDef):
    """Rewrite ``List[X]`` field annotations in place.

    Each field annotated ``List[X]`` is renamed to ``list[X]`` and wrapped as
    ``Annotated[list[X], Field(default_factory=list,
    sa_type=as_sa_type(list[X]))]``. Subscripted annotations using any other
    name are left alone.
    """
    list_fields = [
        stmt
        for stmt in class_def.body
        if isinstance(stmt, ast.AnnAssign)
        and isinstance(stmt.annotation, ast.Subscript)
        and isinstance(stmt.annotation.value, ast.Name)
        and stmt.annotation.value.id == "List"
    ]
    for stmt in list_fields:
        generic = cast(ast.Subscript, stmt.annotation)
        cast(ast.Name, generic.value).id = "list"
        column_type = ast.Call(
            func=ast.Name(id="as_sa_type", ctx=ast.Load()),
            args=[generic],
            keywords=[],
        )
        stmt.annotation = Annotated(
            generic,
            Field(
                default_factory=ast.Name("list", ctx=ast.Load()),
                sa_type=column_type,
            ),
        )
def _fix_name(class_def: ast.ClassDef):
    """Give fields with a bare, non-scalar name annotation an explicit SA type.

    A field annotated with a plain name other than ``str``, ``int``, ``float``
    or ``bool`` becomes ``Annotated[Name, Field(sa_type=as_sa_type(Name))]``.
    """
    scalar_names = frozenset(("str", "int", "float", "bool"))
    for stmt in class_def.body:
        if not isinstance(stmt, ast.AnnAssign):
            continue
        name = stmt.annotation
        if not isinstance(name, ast.Name) or name.id in scalar_names:
            continue
        column_type = ast.Call(
            func=ast.Name(id="as_sa_type", ctx=ast.Load()),
            args=[name],
            keywords=[],
        )
        stmt.annotation = Annotated(name, Field(sa_type=column_type))
def Optional(type_) -> ast.BinOp:
    """Build the AST for the union ``type_ | None``."""
    none_name = ast.Name(id="None", ctx=ast.Load())
    return ast.BinOp(type_, ast.BitOr(), none_name)
def _fix_timestamp(class_def: ast.ClassDef):
    """Retype ``*_at`` fields as ``datetime`` and serialize them back to ints.

    Every annotated field whose name ends in ``_at`` is rewritten in place:

    * a bare-name annotation becomes
      ``Annotated[datetime, Field(default_factory=now)]``;
    * any other annotation becomes ``datetime | None``.

    If at least one such field exists, a ``serialize_datetime`` field
    serializer covering all of them is appended to the class body.
    """
    names: list[str] = []
    any_optional = False
    for stmt in class_def.body:
        if not (
            isinstance(stmt, ast.AnnAssign)
            and isinstance(stmt.target, ast.Name)
            and stmt.target.id.endswith("_at")
        ):
            continue
        if isinstance(stmt.annotation, ast.Name):
            stmt.annotation = Annotated(
                ast.Name(id="datetime", ctx=ast.Load()),
                Field(default_factory=ast.Name(id="now", ctx=ast.Load())),
            )
        else:
            stmt.annotation = Optional(ast.Name(id="datetime", ctx=ast.Load()))
            any_optional = True
        names.append(stmt.target.id)
    if names:
        class_def.body.append(_datetime_serializer(names, any_optional))


def _datetime_serializer(field_names: list[str], optional: bool) -> ast.FunctionDef:
    """Build the ``serialize_datetime`` method appended by ``_fix_timestamp``.

    The method is decorated with ``field_serializer(*field_names)`` and returns
    ``int(dt.timestamp())``. When *optional* is true, ``dt`` is annotated
    ``datetime | None``, the return type is ``int | None``, and a leading
    ``if dt is None: return None`` guard is emitted. Otherwise the types are
    plain ``datetime`` and ``int``.
    """

    def load(identifier: str) -> ast.Name:
        return ast.Name(id=identifier, ctx=ast.Load())

    dt_type = Optional(load("datetime")) if optional else load("datetime")
    return_type = Optional(load("int")) if optional else load("int")

    statements: list[ast.stmt] = []
    if optional:
        statements.append(
            ast.If(
                test=ast.Compare(
                    left=load("dt"), ops=[ast.Is()], comparators=[load("None")]
                ),
                body=[ast.Return(value=ast.Constant(value=None))],
                orelse=[],
            )
        )
    timestamp_call = ast.Call(
        func=ast.Attribute(value=load("dt"), attr="timestamp", ctx=ast.Load()),
        args=[],
        keywords=[],
    )
    statements.append(
        ast.Return(value=ast.Call(func=load("int"), args=[timestamp_call], keywords=[]))
    )

    decorator = ast.Call(
        func=load("field_serializer"),
        args=[ast.Constant(value=name) for name in field_names],
        keywords=[],
    )
    return ast.FunctionDef(
        name="serialize_datetime",
        args=ast.arguments(
            posonlyargs=[],
            args=[ast.arg("self"), ast.arg(arg="dt", annotation=dt_type)],
            kwonlyargs=[],
            kw_defaults=[],
            defaults=[],
        ),
        body=statements,
        decorator_list=[decorator],
        returns=return_type,
        type_params=[],
    )
299def _to_openai_model(
300 openai_model_name: str, object_literal: str
301) -> ast.AsyncFunctionDef:
302 return ast.AsyncFunctionDef(
303 name="to_openai_model",
304 args=ast.arguments(
305 posonlyargs=[],
306 args=[ast.arg(arg="self")],
307 kwonlyargs=[],
308 kw_defaults=[],
309 defaults=[],
310 ),
311 body=[
312 ast.Assign(
313 targets=[ast.Name(id="value", ctx=ast.Store())],
314 value=ast.Call(
315 func=ast.Attribute(
316 value=ast.Name(id="self", ctx=ast.Load()),
317 attr="model_dump",
318 ctx=ast.Load(),
319 ),
320 args=[],
321 keywords=[ast.keyword("by_alias", ast.Constant(True))],
322 ),
323 ),
324 ast.Assign(
325 targets=[
326 ast.Subscript(
327 value=ast.Name(id="value", ctx=ast.Load()),
328 slice=ast.Constant(value="object"),
329 ctx=ast.Store(),
330 )
331 ],
332 value=ast.Constant(value=object_literal),
333 ),
334 ast.Return(
335 value=ast.Call(
336 func=ast.Attribute(
337 value=ast.Name(id=openai_model_name, ctx=ast.Load()),
338 attr="model_validate",
339 ctx=ast.Load(),
340 ),
341 args=[ast.Name(id="value", ctx=ast.Load())],
342 keywords=[],
343 )
344 ),
345 ],
346 decorator_list=[],
347 returns=ast.Name(id=openai_model_name, ctx=ast.Load()),
348 type_params=[],
349 )
def generate_module(module: ModuleType) -> ast.Module:
    """Generate the AST of a SQLModel table module from an OpenAI model module.

    The source of *module* is expected to contain import statements, then a
    top-level ``__all__`` assignment, then helper definitions, with the main
    model class as the last statement. The returned module contains, in order:

    * the original imports, with ``typing_extensions`` redirected to
      ``typing`` and relative imports made absolute;
    * the extra imports the generated code relies on;
    * ``from <module> import <Class> as _<Class>``;
    * imports of the helper definitions from *module*;
    * the main class, rewritten as a ``table=True`` SQLModel.

    Args:
        module: the imported OpenAI model module to convert.

    Returns:
        The generated ``ast.Module``. Node locations are not fixed yet.

    Raises:
        ValueError: if the module has no top-level ``__all__`` assignment.
    """
    tree = ast.parse(inspect.getsource(module))
    all_index = _find_all_index(tree)
    imports: list[ast.stmt] = tree.body[:all_index]
    body = tree.body[all_index + 1 :]
    _absolutize_imports(imports, module.__name__)
    imports.extend(_codegen_imports())
    class_def = cast(ast.ClassDef, body.pop())
    _convert_model_class(class_def)
    # The OpenAI model is imported under a private alias for `to_openai_model`.
    imports.append(
        ast.ImportFrom(
            module.__name__,
            [ast.alias(class_def.name, f"_{class_def.name}")],
            0,
        )
    )
    imports.extend(_import_helper_definitions(body, module.__name__))
    tree.body = imports + [class_def]
    return tree


def _find_all_index(tree: ast.Module) -> int:
    """Return the index of the top-level ``__all__ = ...`` statement in *tree*."""
    for index, node in enumerate(tree.body):
        if (
            isinstance(node, ast.Assign)
            and len(node.targets) == 1
            and isinstance(node.targets[0], ast.Name)
            and node.targets[0].id == "__all__"
        ):
            return index
    raise ValueError("module has no top-level `__all__` assignment")


def _absolutize_imports(imports: list[ast.stmt], module_name: str) -> None:
    """Rewrite ``from`` imports in place so they work outside their package.

    ``typing_extensions`` is redirected to ``typing``. Relative imports are
    resolved against *module_name* and made absolute. Other statements, such
    as plain ``import x``, are left untouched.
    """
    package_parts = module_name.split(".")
    for node in imports:
        if not isinstance(node, ast.ImportFrom):
            continue
        if node.module == "typing_extensions":
            node.module = "typing"
        elif node.level != 0:
            parts = package_parts[: -node.level]
            # `from . import x` has no module name: don't emit a trailing dot.
            if node.module:
                parts = [*parts, node.module]
            node.module = ".".join(parts)
            node.level = 0


def _codegen_imports() -> list[ast.ImportFrom]:
    """Return fresh import nodes for the names the generated code uses."""
    return [
        ast.ImportFrom("datetime", [ast.alias("datetime")], 0),
        ast.ImportFrom("typing", [ast.alias("Annotated")], 0),
        ast.ImportFrom("sqlalchemy.ext.asyncio", [ast.alias("AsyncAttrs")], 0),
        ast.ImportFrom(
            "sqlmodel",
            [
                ast.alias("SQLModel"),
                ast.alias("Enum"),
                ast.alias("Field"),
                ast.alias("Relationship"),
            ],
            0,
        ),
        ast.ImportFrom("pydantic", [ast.alias("field_serializer")], 0),
        ast.ImportFrom("_metadata", [ast.alias("WithMetadata")], 2),
        ast.ImportFrom("_types", [ast.alias("as_sa_type")], 2),
        ast.ImportFrom(
            "_utils", [ast.alias("now"), ast.alias("random_id_with_prefix")], 2
        ),
    ]


def _convert_model_class(class_def: ast.ClassDef) -> None:
    """Rewrite an OpenAI model class in place into a SQLModel table class."""
    # Only annotated fields are kept; every other statement is dropped.
    class_def.body = [
        stmt for stmt in class_def.body if isinstance(stmt, ast.AnnAssign)
    ]
    metadata_field = next(
        (
            a
            for a in class_def.body
            if isinstance(a, ast.AnnAssign)
            and isinstance(a.target, ast.Name)
            and a.target.id == "metadata"
        ),
        None,
    )
    if metadata_field is not None:
        # Dropped from the class; `WithMetadata` becomes the base instead
        # (presumably it provides the metadata column).
        class_def.body.remove(metadata_field)
    class_def.bases = [
        ast.Name(id="AsyncAttrs", ctx=ast.Load()),
        ast.Name(
            id="WithMetadata" if metadata_field is not None else "SQLModel",
            ctx=ast.Load(),
        ),
    ]
    class_def.keywords = [ast.keyword(arg="table", value=ast.Constant(value=True))]
    _fix_id(class_def)
    if tablename := CLS_TO_TABLE.get(class_def.name):
        class_def.body.insert(
            0,
            ast.Assign(
                targets=[ast.Name(id="__tablename__", ctx=ast.Store())],
                value=ast.Constant(value=tablename),
            ),
        )
    object_literal = _get_object_literal(class_def)
    class_def.body.append(_to_openai_model(f"_{class_def.name}", object_literal))
    _fix_literal(class_def)
    _fix_optional(class_def)
    _fix_list(class_def)
    _fix_name(class_def)
    _fix_timestamp(class_def)


def _import_helper_definitions(
    nodes: list[ast.stmt], module_name: str
) -> list[ast.ImportFrom]:
    """Return ``from <module_name> import <name>`` nodes for helper definitions.

    Classes are imported by their name. A module-level annotated assignment is
    imported by its target name when annotated ``TypeAlias``. When it is
    annotated with a subscript, the subscripted name is imported instead.
    NOTE(review): the subscript case imports the annotation's name rather
    than the target; kept as-is, confirm this is intended.
    """
    result = []
    for node in nodes:
        if isinstance(node, ast.ClassDef):
            name = node.name
        elif isinstance(node, ast.AnnAssign) and isinstance(
            node.annotation, ast.Subscript
        ):
            name = cast(ast.Name, node.annotation.value).id
        elif (
            isinstance(node, ast.AnnAssign)
            and isinstance(node.annotation, ast.Name)
            and node.annotation.id == "TypeAlias"
        ):
            name = cast(ast.Name, node.target).id
        else:
            continue
        result.append(ast.ImportFrom(module_name, [ast.alias(name)], 0))
    return result
def Relationship(*, back_populates: str) -> ast.Call:
    """Build the AST for ``Relationship(back_populates="<back_populates>")``."""
    keyword = ast.keyword(
        arg="back_populates", value=ast.Constant(value=back_populates)
    )
    callee = ast.Name(id="Relationship", ctx=ast.Load())
    return ast.Call(func=callee, args=[], keywords=[keyword])
def _add_foreign_key(module: ast.Module, back_populates: dict[str, list[str]]):
    """Turn ``<table>_id`` fields into foreign keys plus relationship attributes.

    For each annotated field of the module's class (its last statement) whose
    name ends in ``_id``:

    * the annotation is wrapped as
      ``Annotated[<type>, Field(foreign_key="<table>.id")]``, with
      ``nullable=True`` added when the annotation is a binary-operator
      expression such as ``str | None``;
    * a relative import of the referenced class is inserted at the top of
      the module;
    * a relationship attribute ``<table>: <Class>`` (``<Class> | None`` when
      nullable) with ``back_populates="<own table>s"`` is appended to the class;
    * this class's table name is recorded in ``back_populates[<table>]``.

    Args:
        module: generated module whose last statement is the table class;
            modified in place.
        back_populates: accumulator mapping a referenced table name to the
            table names that point at it; modified in place.
    """
    class_def = cast(ast.ClassDef, module.body[-1])
    own_table = _c2t(class_def.name)
    # Iterate over a snapshot: relationship attributes are appended to the
    # body below and must not be re-examined as `_id` fields.
    for stmt in list(class_def.body):
        if not (
            isinstance(stmt, ast.AnnAssign)
            and isinstance(stmt.target, ast.Name)
            and stmt.target.id.endswith("_id")
        ):
            continue
        table = stmt.target.id[: -len("_id")]
        f = Field(foreign_key=ast.Constant(value=f"{table}.id"))
        # A `X | None` annotation (e.g. from `_fix_optional`) marks a nullable key.
        optional = isinstance(stmt.annotation, ast.BinOp)
        if optional:
            f.keywords.append(
                ast.keyword(arg="nullable", value=ast.Constant(value=True))
            )
        stmt.annotation = Annotated(stmt.annotation, f)
        cls = _t2c(table)
        module.body.insert(0, ast.ImportFrom(to_snake(cls), [ast.alias(cls)], 1))
        ann = ast.Name(id=cls, ctx=ast.Load())
        class_def.body.append(
            ast.AnnAssign(
                target=ast.Name(id=table, ctx=ast.Store()),
                annotation=Optional(ann) if optional else ann,
                value=Relationship(back_populates=own_table + "s"),
                simple=1,
            )
        )
        back_populates[table].append(own_table)
def _t2c(table_name: str):
    """Map a table name to its model class name.

    Irregular names are looked up in ``TABLE_TO_CLS``. Otherwise the table name
    is capitalized (``"thread"`` -> ``"Thread"``).
    """
    if table_name in TABLE_TO_CLS:
        return TABLE_TO_CLS[table_name]
    return table_name.capitalize()
def _c2t(class_name: str):
    """Map a model class name to its table name.

    Irregular names are looked up in ``CLS_TO_TABLE``. Otherwise the class name
    is lowercased (``"Thread"`` -> ``"thread"``).
    """
    if class_name in CLS_TO_TABLE:
        return CLS_TO_TABLE[class_name]
    return class_name.lower()
def _add_back_populates(module: ast.Module, table: str, back_populates: list[str]):
    """Add the one-to-many side of relationships to a generated table module.

    For each table name ``p`` in *back_populates*, the module's class (its
    last statement) gains ``<p>s: list["<Class of p>"] =
    Relationship(back_populates="<table>")``. The referenced classes are
    imported relatively inside ``if TYPE_CHECKING:`` (preceded by
    ``from typing import TYPE_CHECKING``), so they are not imported at runtime.
    The annotations refer to them by string.
    """
    *preceding, last = module.body
    class_def = cast(ast.ClassDef, last)
    relations = []
    guarded_imports = []
    for referrer in back_populates:
        cls_name = _t2c(referrer)
        relations.append(
            ast.AnnAssign(
                target=ast.Name(id=referrer + "s", ctx=ast.Store()),
                annotation=ast.Subscript(
                    value=ast.Name(id="list", ctx=ast.Load()),
                    slice=ast.Constant(value=cls_name),
                    ctx=ast.Load(),
                ),
                value=Relationship(back_populates=table),
                simple=1,
            )
        )
        guarded_imports.append(
            ast.ImportFrom(to_snake(cls_name), names=[ast.alias(cls_name)], level=1)
        )
    class_def.body.extend(relations)
    type_checking_block = ast.If(
        test=ast.Name("TYPE_CHECKING", ctx=ast.Load()),
        body=guarded_imports,
        orelse=[],
    )
    module.body = [
        *preceding,
        ast.ImportFrom("typing", names=[ast.alias("TYPE_CHECKING")], level=0),
        type_checking_block,
        class_def,
    ]
def generate_modules(*modules: ModuleType):
    """Generate SQLModel table modules for *modules* and write them to disk.

    Each module is converted with ``generate_module``. Foreign keys and their
    reverse relationships are then wired up across all generated modules.
    Each result is written to ``generated/<module basename>.py`` next to this
    file, and the ``generated`` directory is created if missing.

    Generated modules are keyed by the table name of their class (``_c2t``
    of the class name). This matches the table names ``_add_foreign_key``
    derives from ``<table>_id`` fields; for example, ``FileObject`` is keyed
    ``"file"``, not ``"file_object"``.
    """
    module_map: dict[str, ast.Module] = {}
    file_names: dict[str, str] = {}
    for module in modules:
        tree = generate_module(module)
        # `_c2t` maps class names, so derive the key from the generated class
        # rather than from the module's basename.
        table = _c2t(cast(ast.ClassDef, tree.body[-1]).name)
        module_map[table] = tree
        file_names[table] = module.__name__.split(".")[-1]
    back_populates: dict[str, list[str]] = defaultdict(list)
    for tree in module_map.values():
        _add_foreign_key(tree, back_populates)
    for table, back in back_populates.items():
        _add_back_populates(module_map[table], table, back)
    out_dir = Path(__file__).parent / "generated"
    out_dir.mkdir(parents=True, exist_ok=True)
    for table, tree in module_map.items():
        dest = (out_dir / file_names[table]).with_suffix(".py")
        dest.write_text(
            ast.unparse(ast.fix_missing_locations(tree)), encoding="utf-8"
        )