Coverage for src/lazy_imports_lite/_transformer.py: 100%

92 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-30 21:25 +0100

1import ast 

2from typing import Any 

3 

4header = """ 

5import lazy_imports_lite._hooks as __lazy_imports_lite__ 

6globals=__lazy_imports_lite__.make_globals(lambda g=globals:g()) 

7""" 

8header_ast = ast.parse(header).body 

9 

10 

11class TransformModuleImports(ast.NodeTransformer): 

12 def __init__(self): 

13 self.transformed_imports = [] 

14 self.functions = [] 

15 self.context = [] 

16 

17 self.globals = set() 

18 self.locals = set() 

19 self.in_function = False 

20 

21 def visit_ImportFrom(self, node: ast.ImportFrom) -> Any: 

22 if self.context[-1] != "Module": 

23 return node 

24 

25 if node.module == "__future__": 

26 return node 

27 

28 new_nodes = [] 

29 for alias in node.names: 

30 name = alias.asname or alias.name 

31 

32 module = "." * (node.level) + (node.module or "") 

33 new_nodes.append( 

34 ast.Assign( 

35 targets=[ast.Name(id=name, ctx=ast.Store())], 

36 value=ast.Call( 

37 func=ast.Attribute( 

38 value=ast.Name(id="__lazy_imports_lite__", ctx=ast.Load()), 

39 attr="ImportFrom", 

40 ctx=ast.Load(), 

41 ), 

42 args=[ 

43 ast.Name(id="__package__", ctx=ast.Load()), 

44 ast.Constant(value=module, kind=None), 

45 ast.Constant(alias.name, kind=None), 

46 ], 

47 keywords=[], 

48 ), 

49 ) 

50 ) 

51 self.transformed_imports.append(name) 

52 return new_nodes 

53 

54 def visit_Import(self, node: ast.Import) -> Any: 

55 if len(self.context) > 1: 

56 return node 

57 

58 new_nodes = [] 

59 for alias in node.names: 

60 if alias.asname: 

61 name = alias.asname 

62 new_nodes.append( 

63 ast.Assign( 

64 targets=[ast.Name(id=name, ctx=ast.Store())], 

65 value=ast.Call( 

66 func=ast.Attribute( 

67 value=ast.Name( 

68 id="__lazy_imports_lite__", ctx=ast.Load() 

69 ), 

70 attr="ImportAs", 

71 ctx=ast.Load(), 

72 ), 

73 args=[ast.Constant(value=alias.name, kind=None)], 

74 keywords=[], 

75 ), 

76 ) 

77 ) 

78 self.transformed_imports.append(name) 

79 else: 

80 name = alias.name.split(".")[0] 

81 new_nodes.append( 

82 ast.Assign( 

83 targets=[ast.Name(id=name, ctx=ast.Store())], 

84 value=ast.Call( 

85 func=ast.Attribute( 

86 value=ast.Name( 

87 id="__lazy_imports_lite__", ctx=ast.Load() 

88 ), 

89 attr="Import", 

90 ctx=ast.Load(), 

91 ), 

92 args=[ast.Constant(value=alias.name, kind=None)], 

93 keywords=[], 

94 ), 

95 ) 

96 ) 

97 self.transformed_imports.append(name) 

98 

99 return new_nodes 

100 

101 def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: 

102 return self.handle_function(node) 

103 

104 def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any: 

105 return self.handle_function(node) 

106 

107 def visit_Lambda(self, node: ast.Lambda) -> Any: 

108 return self.handle_function(node) 

109 

110 def handle_function(self, function): 

111 for field, value in ast.iter_fields(function): 

112 if field != "body": 

113 if isinstance(value, list): 

114 setattr(function, field, [self.visit(v) for v in value]) 

115 elif isinstance(value, ast.AST): 

116 setattr(function, field, self.visit(value)) 

117 self.functions.append(function) 

118 

119 return function 

120 

121 def handle_function_body(self, function: ast.FunctionDef): 

122 args = [ 

123 *function.args.posonlyargs, 

124 *function.args.args, 

125 function.args.vararg, 

126 *function.args.kwonlyargs, 

127 function.args.kwarg, 

128 ] 

129 

130 self.locals = {arg.arg for arg in args if arg is not None} 

131 

132 self.globals = set() 

133 

134 self.in_function = True 

135 

136 if isinstance(function.body, list): 

137 function.body = [self.visit(b) for b in function.body] 

138 else: 

139 function.body = self.visit(function.body) 

140 

141 def visit_Global(self, node: ast.Global) -> Any: 

142 self.globals.update(node.names) 

143 return self.generic_visit(node) 

144 

145 def visit_Name(self, node: ast.Name) -> Any: 

146 if isinstance(node.ctx, ast.Store) and ( 

147 node.id not in self.globals or not self.in_function 

148 ): 

149 self.locals.add(node.id) 

150 

151 if node.id in self.transformed_imports and node.id not in self.locals: 

152 old_ctx = node.ctx 

153 node.ctx = ast.Load() 

154 return ast.Attribute(value=node, attr="v", ctx=old_ctx) 

155 else: 

156 return node 

157 

158 def visit_Module(self, module: ast.Module) -> Any: 

159 module = self.generic_visit(module) 

160 assert len(self.context) == 0 

161 

162 pos = 0 

163 

164 def is_import_from_future(node): 

165 return ( 

166 isinstance(node, ast.Expr) 

167 and isinstance(node.value, ast.Constant) 

168 and isinstance(node.value.value, str) 

169 or isinstance(node, ast.ImportFrom) 

170 and node.module == "__future__" 

171 ) 

172 

173 if module.body: 

174 while is_import_from_future(module.body[pos]): 

175 pos += 1 

176 module.body[pos:pos] = header_ast 

177 

178 self.context = ["FunctionBody"] 

179 while self.functions: 

180 f = self.functions.pop() 

181 self.handle_function_body(f) 

182 

183 return module 

184 

185 def generic_visit(self, node: ast.AST) -> ast.AST: 

186 ctx_len = len(self.context) 

187 self.context.append(type(node).__name__) 

188 result = super().generic_visit(node) 

189 self.context = self.context[:ctx_len] 

190 return result