Coverage for src/lazy_imports_lite/_transformer.py: 100%
92 statements
coverage.py v7.4.0, created at 2024-01-30 21:25 +0100
1import ast
2from typing import Any
# Source text of the prologue that TransformModuleImports.visit_Module splices
# into each transformed module: it binds the hooks module under the name
# ``__lazy_imports_lite__`` (which the generated assignments reference) and
# rebinds ``globals`` through ``make_globals``.
header = (
    "\n"
    "import lazy_imports_lite._hooks as __lazy_imports_lite__\n"
    "globals=__lazy_imports_lite__.make_globals(lambda g=globals:g())\n"
)

# Parsed once at import time; the resulting statement list is what gets
# inserted into the module body.
header_ast = ast.parse(header).body
class TransformModuleImports(ast.NodeTransformer):
    """Rewrite module-level imports into ``__lazy_imports_lite__`` assignments.

    Each top-level ``import`` / ``from ... import`` statement is replaced by
    one assignment per imported name, of the form
    ``name = __lazy_imports_lite__.<Import|ImportAs|ImportFrom>(...)``.
    Every later reference to such a name that is not shadowed by a local
    binding is rewritten to ``name.v``.

    The transformation runs in two passes, both driven by ``visit_Module``:

    1. The module is traversed; function and lambda bodies are *not* entered
       but the function nodes are collected in ``self.functions``.
    2. The collected functions are processed one by one (including functions
       discovered while processing another body), each with a fresh set of
       local and ``global``-declared names.

    Imports nested in any other statement (``if``, ``try``, ``class``, a
    function body, ...) are left untouched.
    """

    def __init__(self):
        # Names bound by rewritten imports, in order of appearance.
        self.transformed_imports = []
        # Function/lambda nodes whose bodies still have to be processed.
        self.functions = []
        # Stack of node type names maintained by generic_visit(); it is
        # replaced by ["FunctionBody"] for the deferred function-body pass.
        self.context = []

        # Names declared ``global`` in the function currently processed.
        self.globals = set()
        # Names that are bound locally and therefore must not become ``.v``.
        self.locals = set()
        # True once the deferred function-body pass has started.
        self.in_function = False

    @staticmethod
    def _lazy_assign(name, factory, args):
        """Build ``name = __lazy_imports_lite__.<factory>(*args)`` as an AST node."""
        return ast.Assign(
            targets=[ast.Name(id=name, ctx=ast.Store())],
            value=ast.Call(
                func=ast.Attribute(
                    value=ast.Name(id="__lazy_imports_lite__", ctx=ast.Load()),
                    attr=factory,
                    ctx=ast.Load(),
                ),
                args=args,
                keywords=[],
            ),
        )

    def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
        """Rewrite a top-level ``from X import a, b as c`` statement.

        Returns the node unchanged when it is not a direct child of the
        module, when it is a ``__future__`` import, or when it is a star
        import; otherwise returns a list with one assignment per alias.
        """
        if self.context[-1] != "Module":
            return node

        if node.module == "__future__":
            return node

        # ``from X import *`` binds names that are not known statically, so
        # there is nothing that could be assigned a lazy wrapper.
        if any(alias.name == "*" for alias in node.names):
            return node

        # Relative imports keep their leading dots, e.g. "..pkg.mod".
        module = "." * (node.level or 0) + (node.module or "")

        new_nodes = []
        for alias in node.names:
            name = alias.asname or alias.name
            new_nodes.append(
                self._lazy_assign(
                    name,
                    "ImportFrom",
                    [
                        ast.Name(id="__package__", ctx=ast.Load()),
                        ast.Constant(value=module, kind=None),
                        ast.Constant(alias.name, kind=None),
                    ],
                )
            )
            self.transformed_imports.append(name)
        return new_nodes

    def visit_Import(self, node: ast.Import) -> Any:
        """Rewrite a top-level ``import a.b, c as d`` statement.

        Returns the node unchanged when the import is nested inside another
        statement or sits in a function body; otherwise returns a list with
        one assignment per alias (``ImportAs`` for aliased names, ``Import``
        for plain ones, which bind the first dotted component).
        """
        # During the function-body pass the context is ["FunctionBody"], which
        # has length 1 just like ["Module"]; without the second test an import
        # directly inside a function would be rewritten and a list would end
        # up nested in the function body.
        if len(self.context) > 1 or self.context == ["FunctionBody"]:
            return node

        new_nodes = []
        for alias in node.names:
            if alias.asname:
                name, factory = alias.asname, "ImportAs"
            else:
                # ``import a.b.c`` binds only the top-level name ``a``.
                name, factory = alias.name.split(".")[0], "Import"
            new_nodes.append(
                self._lazy_assign(
                    name, factory, [ast.Constant(value=alias.name, kind=None)]
                )
            )
            self.transformed_imports.append(name)

        return new_nodes

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        """Defer the body of a ``def``; see ``handle_function``."""
        return self.handle_function(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
        """Defer the body of an ``async def``; see ``handle_function``."""
        return self.handle_function(node)

    def visit_Lambda(self, node: ast.Lambda) -> Any:
        """Defer the body of a ``lambda``; see ``handle_function``."""
        return self.handle_function(node)

    def handle_function(self, function):
        """Visit everything of *function* except its body and queue the body.

        All fields other than ``body`` are evaluated in the enclosing scope,
        so they are transformed immediately. The function itself is appended
        to ``self.functions`` for ``handle_function_body``.
        """
        for field, value in ast.iter_fields(function):
            if field != "body":
                if isinstance(value, list):
                    setattr(function, field, [self.visit(v) for v in value])
                elif isinstance(value, ast.AST):
                    setattr(function, field, self.visit(value))
        self.functions.append(function)

        return function

    def handle_function_body(self, function: ast.FunctionDef):
        """Transform the body of a previously queued function or lambda.

        Resets the local/global name sets: the parameters are the initial
        locals, further locals are added by ``visit_Name`` as stores are
        encountered.
        """
        args = [
            *function.args.posonlyargs,
            *function.args.args,
            function.args.vararg,
            *function.args.kwonlyargs,
            function.args.kwarg,
        ]

        # vararg / kwarg are None when the function does not declare them.
        self.locals = {arg.arg for arg in args if arg is not None}

        self.globals = set()

        self.in_function = True

        # A lambda body is a single expression, a def body a statement list.
        if isinstance(function.body, list):
            function.body = [self.visit(b) for b in function.body]
        else:
            function.body = self.visit(function.body)

    def visit_Global(self, node: ast.Global) -> Any:
        """Record the names of a ``global`` statement for ``visit_Name``."""
        self.globals.update(node.names)
        return self.generic_visit(node)

    def visit_Name(self, node: ast.Name) -> Any:
        """Rewrite a reference to a lazily imported name into ``name.v``.

        A store to a name marks it as local, unless the store happens inside
        a function that declared the name ``global``. Names that are in
        ``transformed_imports`` and not local are wrapped in an attribute
        access that carries the original expression context, while the name
        itself becomes a load.
        """
        if isinstance(node.ctx, ast.Store) and (
            node.id not in self.globals or not self.in_function
        ):
            self.locals.add(node.id)

        if node.id in self.transformed_imports and node.id not in self.locals:
            old_ctx = node.ctx
            node.ctx = ast.Load()
            return ast.Attribute(value=node, attr="v", ctx=old_ctx)
        else:
            return node

    def visit_Module(self, module: ast.Module) -> Any:
        """Transform a whole module and return it.

        Runs the module-level pass, inserts the ``header_ast`` statements
        after any leading string-expression statements and ``__future__``
        imports, then processes all queued function bodies.
        """
        module = self.generic_visit(module)
        assert len(self.context) == 0

        pos = 0

        def is_import_from_future(node):
            # True for a bare string expression (e.g. the module docstring)
            # and for ``from __future__ import ...``.
            return (
                isinstance(node, ast.Expr)
                and isinstance(node.value, ast.Constant)
                and isinstance(node.value.value, str)
                or isinstance(node, ast.ImportFrom)
                and node.module == "__future__"
            )

        # The bounds check matters for modules that consist only of a
        # docstring and/or __future__ imports: the scan would otherwise run
        # past the end of the body and raise IndexError.
        while pos < len(module.body) and is_import_from_future(module.body[pos]):
            pos += 1
        module.body[pos:pos] = header_ast

        # Second pass: function bodies. Processing a body may queue further
        # (nested) functions, so loop until the list is drained.
        self.context = ["FunctionBody"]
        while self.functions:
            f = self.functions.pop()
            self.handle_function_body(f)

        return module

    def generic_visit(self, node: ast.AST) -> ast.AST:
        """Visit the children of *node* while tracking the nesting context.

        The type name of *node* is pushed onto ``self.context`` for the
        duration of the traversal and the stack is restored afterwards.
        """
        ctx_len = len(self.context)
        self.context.append(type(node).__name__)
        result = super().generic_visit(node)
        self.context = self.context[:ctx_len]
        return result