Coverage for src/lazy_imports_lite/_transformer.py: 100%

93 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-02-12 09:19 +0100

1import ast 

2import typing 

3from typing import Any 

4 

# Source prologue injected into every transformed module: it binds the hooks
# module under a private alias and rebinds ``globals`` to the wrapper returned
# by ``make_globals``.  The text is spelled out line by line so that the exact
# characters handed to ``ast.parse`` are easy to see.
header = (
    "\n"
    "import lazy_imports_lite._hooks as __lazy_imports_lite__\n"
    "globals=__lazy_imports_lite__.make_globals(lambda g=globals:g())\n"
)

# Parsed once at import time; ``.body`` is the list of statement nodes
# (one ``import`` and one assignment).
header_ast = ast.parse(header).body

10 

11 

12class TransformModuleImports(ast.NodeTransformer): 

13 def __init__(self): 

14 self.transformed_imports = [] 

15 self.functions = [] 

16 self.context = [] 

17 

18 self.globals = set() 

19 self.locals = set() 

20 self.in_function = False 

21 

22 def visit_ImportFrom(self, node: ast.ImportFrom) -> Any: 

23 if self.context[-1] != "Module": 

24 return node 

25 

26 if node.module == "__future__": 

27 return node 

28 

29 new_nodes = [] 

30 for alias in node.names: 

31 name = alias.asname or alias.name 

32 

33 module = "." * (node.level) + (node.module or "") 

34 new_nodes.append( 

35 ast.Assign( 

36 targets=[ast.Name(id=name, ctx=ast.Store())], 

37 value=ast.Call( 

38 func=ast.Attribute( 

39 value=ast.Name(id="__lazy_imports_lite__", ctx=ast.Load()), 

40 attr="ImportFrom", 

41 ctx=ast.Load(), 

42 ), 

43 args=[ 

44 ast.Name(id="__package__", ctx=ast.Load()), 

45 ast.Constant(value=module, kind=None), 

46 ast.Constant(alias.name, kind=None), 

47 ], 

48 keywords=[], 

49 ), 

50 ) 

51 ) 

52 self.transformed_imports.append(name) 

53 return new_nodes 

54 

55 def visit_Import(self, node: ast.Import) -> Any: 

56 if len(self.context) > 1: 

57 return node 

58 

59 new_nodes = [] 

60 for alias in node.names: 

61 if alias.asname: 

62 name = alias.asname 

63 new_nodes.append( 

64 ast.Assign( 

65 targets=[ast.Name(id=name, ctx=ast.Store())], 

66 value=ast.Call( 

67 func=ast.Attribute( 

68 value=ast.Name( 

69 id="__lazy_imports_lite__", ctx=ast.Load() 

70 ), 

71 attr="ImportAs", 

72 ctx=ast.Load(), 

73 ), 

74 args=[ast.Constant(value=alias.name, kind=None)], 

75 keywords=[], 

76 ), 

77 ) 

78 ) 

79 self.transformed_imports.append(name) 

80 else: 

81 name = alias.name.split(".")[0] 

82 new_nodes.append( 

83 ast.Assign( 

84 targets=[ast.Name(id=name, ctx=ast.Store())], 

85 value=ast.Call( 

86 func=ast.Attribute( 

87 value=ast.Name( 

88 id="__lazy_imports_lite__", ctx=ast.Load() 

89 ), 

90 attr="Import", 

91 ctx=ast.Load(), 

92 ), 

93 args=[ast.Constant(value=alias.name, kind=None)], 

94 keywords=[], 

95 ), 

96 ) 

97 ) 

98 self.transformed_imports.append(name) 

99 

100 return new_nodes 

101 

102 def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: 

103 return self.handle_function(node) 

104 

105 def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any: 

106 return self.handle_function(node) 

107 

108 def visit_Lambda(self, node: ast.Lambda) -> Any: 

109 return self.handle_function(node) 

110 

111 def handle_function(self, function): 

112 for field, value in ast.iter_fields(function): 

113 if field != "body": 

114 if isinstance(value, list): 

115 setattr(function, field, [self.visit(item) for item in value]) 

116 elif isinstance(value, ast.AST): 

117 setattr(function, field, self.visit(value)) 

118 self.functions.append(function) 

119 

120 return function 

121 

122 def handle_function_body(self, function: ast.FunctionDef): 

123 args = [ 

124 *function.args.posonlyargs, 

125 *function.args.args, 

126 function.args.vararg, 

127 *function.args.kwonlyargs, 

128 function.args.kwarg, 

129 ] 

130 

131 self.locals = {arg.arg for arg in args if arg is not None} 

132 

133 self.globals = set() 

134 

135 self.in_function = True 

136 

137 if isinstance(function.body, list): 

138 function.body = [self.visit(b) for b in function.body] 

139 else: 

140 function.body = self.visit(function.body) 

141 

142 def visit_Global(self, node: ast.Global) -> Any: 

143 self.globals.update(node.names) 

144 return self.generic_visit(node) 

145 

146 def visit_Name(self, node: ast.Name) -> Any: 

147 if isinstance(node.ctx, ast.Store) and ( 

148 node.id not in self.globals or not self.in_function 

149 ): 

150 self.locals.add(node.id) 

151 

152 if node.id in self.transformed_imports and node.id not in self.locals: 

153 old_ctx = node.ctx 

154 node.ctx = ast.Load() 

155 return ast.Attribute(value=node, attr="_lazy_value", ctx=old_ctx) 

156 else: 

157 return node 

158 

159 def visit_Module(self, module: ast.Module) -> Any: 

160 module = typing.cast(ast.Module, self.generic_visit(module)) 

161 assert len(self.context) == 0 

162 

163 pos = 0 

164 

165 def is_import_from_future(node): 

166 return ( 

167 isinstance(node, ast.Expr) 

168 and isinstance(node.value, ast.Constant) 

169 and isinstance(node.value.value, str) 

170 or isinstance(node, ast.ImportFrom) 

171 and node.module == "__future__" 

172 ) 

173 

174 if module.body: 

175 while is_import_from_future(module.body[pos]): 

176 pos += 1 

177 module.body[pos:pos] = header_ast 

178 

179 self.context = ["FunctionBody"] 

180 while self.functions: 

181 f = self.functions.pop() 

182 self.handle_function_body(f) 

183 

184 return module 

185 

186 def generic_visit(self, node: ast.AST) -> ast.AST: 

187 ctx_len = len(self.context) 

188 self.context.append(type(node).__name__) 

189 result = super().generic_visit(node) 

190 self.context = self.context[:ctx_len] 

191 return result