Coverage for zanj\loading.py: 80%

126 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-14 12:57 -0700

1from __future__ import annotations 

2 

3import json 

4import threading 

5import typing 

6import zipfile 

7from dataclasses import dataclass 

8from pathlib import Path 

9from typing import Any, Callable 

10 

11import numpy as np 

12 

try:
    import pandas as pd  # type: ignore[import]

    # pandas is available: expose the real DataFrame class under the alias
    pandas_DataFrame = pd.DataFrame  # type: ignore[no-redef]
except ImportError:
    # pandas is optional -- install a stub that fails loudly only when a
    # DataFrame is actually constructed, so the rest of the module still works
    class pandas_DataFrame:  # type: ignore[no-redef]
        """placeholder for `pandas.DataFrame` when pandas is not installed"""

        def __init__(self, *args, **kwargs):
            raise ImportError("cannot load pandas DataFrame, pandas is not installed")

22 

23 

24import torch 

25from muutils.errormode import ErrorMode 

26from muutils.json_serialize.array import load_array 

27from muutils.json_serialize.json_serialize import ObjectPath 

28from muutils.json_serialize.util import ( 

29 _FORMAT_KEY, 

30 _REF_KEY, 

31 JSONdict, 

32 JSONitem, 

33 safe_getsource, 

34 string_as_lines, 

35) 

36from muutils.tensor_utils import DTYPE_MAP, TORCH_DTYPE_MAP 

37 

38from zanj.externals import ( 

39 GET_EXTERNAL_LOAD_FUNC, 

40 ZANJ_MAIN, 

41 ZANJ_META, 

42 ExternalItem, 

43 _ZANJ_pre, 

44) 

45 

46# pylint: disable=protected-access, dangerous-default-value 

47 

48 

49def _populate_externals_error_checking(key, item) -> bool: 

50 """checks that the key is valid for the item. returns "True" we need to augment the path by accessing the "data" element""" 

51 

52 # special case for not fully loaded external item which we still need to populate 

53 if isinstance(item, typing.Mapping): 

54 if (_FORMAT_KEY in item) and item[_FORMAT_KEY].endswith(":external"): 

55 if "data" in item: 

56 return True 

57 else: 

58 raise KeyError( 

59 f"expected an external item, but could not find data: {list(item.keys())}", 

60 f"{item[_FORMAT_KEY]}, {len(item) = }, {item.get('data', '<EMPTY>') = }", 

61 ) 

62 

63 # if it's a list, make sure the key is an int and that it's in range 

64 if isinstance(item, typing.Sequence): 

65 if not isinstance(key, int): 

66 raise TypeError(f"improper type: '{type(key) = }', expected int") 

67 if key >= len(item): 

68 raise IndexError(f"index out of range: '{key = }', expected < {len(item)}") 

69 

70 # if it's a dict, make sure that the key is a str and that it's in the dict 

71 elif isinstance(item, typing.Mapping): 

72 if not isinstance(key, str): 

73 raise TypeError(f"improper type: '{type(key) = }', expected str") 

74 if key not in item: 

75 raise KeyError(f"key not in dict: '{key = }', expected in {item.keys()}") 

76 

77 # otherwise, raise an error 

78 else: 

79 raise TypeError(f"improper type: '{type(item) = }', expected dict or list") 

80 

81 return False 

82 

83 

@dataclass
class LoaderHandler:
    """handler for loading an object from a json file or a ZANJ archive"""

    # TODO: add a separate "asserts" function?
    # right now, any asserts must happen in `check` or `load` which is annoying with lambdas

    # (json_data, path) -> whether to use this handler
    check: Callable[[JSONitem, ObjectPath, _ZANJ_pre], bool]
    # function to load the object (json_data, path) -> loaded_obj
    load: Callable[[JSONitem, ObjectPath, _ZANJ_pre], Any]
    # unique identifier for the handler, saved in __muutils_format__ field
    uid: str
    # source package of the handler -- note that this might be overridden by ZANJ
    source_pckg: str
    # priority of the handler, defaults are all 0
    priority: int = 0
    # description of the handler
    desc: str = "(no description)"

    def serialize(self) -> JSONdict:
        """serialize the handler info"""

        def _callable_info(fn: Callable) -> JSONdict:
            # record the source code and docstring of a handler callable
            return {
                "code": safe_getsource(fn),
                "doc": string_as_lines(fn.__doc__),
            }

        return {
            "check": _callable_info(self.check),
            "load": _callable_info(self.load),
            "uid": str(self.uid),
            "source_pckg": str(self.source_pckg),
            "priority": int(self.priority),
            "desc": str(self.desc),
        }

    @classmethod
    def from_formattedclass(cls, fc: type, priority: int = 0):
        """create a loader from a class with `serialize`, `load` methods and `__muutils_format__` attribute"""
        # the class must expose callable serialize/load methods
        for method_name in ("serialize", "load"):
            assert hasattr(fc, method_name)
            assert callable(getattr(fc, method_name))  # type: ignore
        # and a string format identifier
        assert hasattr(fc, _FORMAT_KEY)
        assert isinstance(fc.__muutils_format__, str)  # type: ignore

        return cls(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                json_item[_FORMAT_KEY] == fc.__muutils_format__  # type: ignore[attr-defined]
            ),
            load=lambda json_item, path=None, z=None: fc.load(json_item, path, z),  # type: ignore[misc]
            uid=fc.__muutils_format__,  # type: ignore[attr-defined]
            source_pckg=str(fc.__module__),
            priority=priority,
            desc=f"formatted class loader for {fc.__name__}",
        )

144 

145 

# NOTE: there are type ignores on the loaders, since the type checking should be the responsibility of the check function

# lock guarding mutation of LOADER_MAP (held by `register_loader_handler`)
LOADER_MAP_LOCK = threading.Lock()

# default registry of loader handlers, keyed by handler uid
# (the uid is matched against the serialized item's `_FORMAT_KEY` field)
LOADER_MAP: dict[str, LoaderHandler] = {
    lh.uid: lh
    for lh in [
        # array external
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("numpy.ndarray")
                # and json_item["data"].dtype.name == json_item["dtype"]
                # and tuple(json_item["data"].shape) == tuple(json_item["shape"])
            ),
            load=lambda json_item, path=None, z=None: np.array(  # type: ignore[misc]
                load_array(json_item), dtype=DTYPE_MAP[json_item["dtype"]]
            ),
            uid="numpy.ndarray",
            source_pckg="zanj",
            desc="numpy.ndarray loader",
        ),
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("torch.Tensor")
                # and json_item["data"].dtype.name == json_item["dtype"]
                # and tuple(json_item["data"].shape) == tuple(json_item["shape"])
            ),
            load=lambda json_item, path=None, z=None: torch.tensor(  # type: ignore[misc]
                load_array(json_item), dtype=TORCH_DTYPE_MAP[json_item["dtype"]]
            ),
            uid="torch.Tensor",
            source_pckg="zanj",
            desc="torch.Tensor loader",
        ),
        # pandas
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("pandas.DataFrame")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            # note: uses the module-level `pandas_DataFrame` alias, which raises
            # ImportError at construction time if pandas is not installed
            load=lambda json_item, path=None, z=None: pandas_DataFrame(  # type: ignore[misc]
                json_item["data"]
            ),
            uid="pandas.DataFrame",
            source_pckg="zanj",
            desc="pandas.DataFrame loader",
        ),
        # list/tuple external
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("list")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            # elements are themselves loaded recursively
            load=lambda json_item, path=None, z=None: [  # type: ignore[misc]
                load_item_recursive(x, path, z) for x in json_item["data"]
            ],
            uid="list",
            source_pckg="zanj",
            desc="list loader, for externals",
        ),
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("tuple")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            load=lambda json_item, path=None, z=None: tuple(  # type: ignore[misc]
                [load_item_recursive(x, path, z) for x in json_item["data"]]
            ),
            uid="tuple",
            source_pckg="zanj",
            desc="tuple loader, for externals",
        ),
    ]
}

233 

234 

def register_loader_handler(handler: LoaderHandler):
    """register a custom loader handler

    the handler is stored in the global `LOADER_MAP` under its `uid`,
    replacing any existing handler with the same uid
    """
    global LOADER_MAP, LOADER_MAP_LOCK
    # hold the lock so concurrent registrations cannot interleave
    with LOADER_MAP_LOCK:
        LOADER_MAP[handler.uid] = handler

240 

241 

def get_item_loader(
    json_item: JSONitem,
    path: ObjectPath,
    zanj: _ZANJ_pre | None = None,
    error_mode: ErrorMode = ErrorMode.WARN,
    # lh_map: dict[str, LoaderHandler] = LOADER_MAP,
) -> LoaderHandler | None:
    """get the loader for a json item

    resolution order:
    1. if `json_item` is a mapping carrying a string `_FORMAT_KEY`, look that
       format uid up directly in `LOADER_MAP`
    2. otherwise, ask each registered handler's `check` function in turn

    returns `None` when no handler matches

    NOTE: `error_mode` is accepted for interface symmetry with
    `load_item_recursive` but is currently unused here

    # Raises:
    - `TypeError` : if `_FORMAT_KEY` is present but not a string
    """
    # fast path: direct lookup by format uid
    if isinstance(json_item, typing.Mapping) and _FORMAT_KEY in json_item:
        if not isinstance(json_item[_FORMAT_KEY], str):
            raise TypeError(
                f"invalid __muutils_format__ type '{type(json_item[_FORMAT_KEY])}' in '{path=}': '{json_item[_FORMAT_KEY] = }'"
            )
        if json_item[_FORMAT_KEY] in LOADER_MAP:
            return LOADER_MAP[json_item[_FORMAT_KEY]]  # type: ignore[index]

    # if we dont recognize the format, try to find a loader that can handle it
    # (keys are not needed here, so iterate values only)
    for lh in LOADER_MAP.values():
        if lh.check(json_item, path, zanj):
            return lh

    # if we still dont have a loader, return None
    return None

268 

269 

def load_item_recursive(
    json_item: JSONitem,
    path: ObjectPath,
    zanj: _ZANJ_pre | None = None,
    error_mode: ErrorMode = ErrorMode.WARN,
    allow_not_loading: bool = True,
) -> Any:
    """recursively load a json item, dispatching to a registered `LoaderHandler` when one matches

    # Parameters:
    - `json_item : JSONitem` -- the json data to load
    - `path : ObjectPath` -- path of this item within the overall json data
    - `zanj : _ZANJ_pre | None` -- ZANJ object, passed through to handlers
    - `error_mode : ErrorMode` -- passed through to `get_item_loader`
    - `allow_not_loading : bool` -- if True, items of unrecognized type are
      returned as-is; if False, a `ValueError` is raised for them

    # Returns:
    - the handler's loaded object, a container of recursively loaded items,
      or the primitive value itself
    """
    lh: LoaderHandler | None = get_item_loader(
        json_item=json_item,
        path=path,
        zanj=zanj,
        error_mode=error_mode,
        # lh_map=lh_map,
    )

    if lh is not None:
        # special case for serializable dataclasses
        if (
            isinstance(json_item, typing.Mapping)
            and (_FORMAT_KEY in json_item)
            and ("SerializableDataclass" in json_item[_FORMAT_KEY])  # type: ignore[operator]
        ):
            # why this horribleness?
            # SerializableDataclass, if it has a field `x` which is also a SerializableDataclass, will automatically call `x.__class__.load()`
            # However, we need to load things in containers, as well as arrays
            # so: recursively load each value EXCEPT nested SerializableDataclass
            # values, which the handler's own `load` will take care of
            processed_json_item: dict = {
                key: (
                    val
                    if (
                        isinstance(val, typing.Mapping)
                        and (_FORMAT_KEY in val)
                        and ("SerializableDataclass" in val[_FORMAT_KEY])
                    )
                    else load_item_recursive(
                        json_item=val,
                        path=tuple(path) + (key,),
                        zanj=zanj,
                        error_mode=error_mode,
                    )
                )
                for key, val in json_item.items()
            }

            return lh.load(processed_json_item, path, zanj)

        else:
            # a handler matched: let it load the raw item directly
            return lh.load(json_item, path, zanj)
    else:
        # no handler matched: recurse into containers, pass primitives through
        if isinstance(json_item, dict):
            return {
                key: load_item_recursive(
                    json_item=json_item[key],
                    path=tuple(path) + (key,),
                    zanj=zanj,
                    error_mode=error_mode,
                    # lh_map=lh_map,
                )
                for key in json_item
            }
        elif isinstance(json_item, list):
            return [
                load_item_recursive(
                    json_item=x,
                    path=tuple(path) + (i,),
                    zanj=zanj,
                    error_mode=error_mode,
                    # lh_map=lh_map,
                )
                for i, x in enumerate(json_item)
            ]
        elif isinstance(json_item, (str, int, float, bool, type(None))):
            # json primitives are returned unchanged
            return json_item
        else:
            if allow_not_loading:
                return json_item
            else:
                raise ValueError(
                    f"unknown type {type(json_item)} at {path}\n{json_item}"
                )

349 

350 

def _each_item_in_externals(
    externals: dict[str, ExternalItem],
    json_data: JSONitem,
) -> typing.Iterable[tuple[str, ExternalItem, Any, ObjectPath]]:
    """note that you MUST use the raw iterator, dont try to turn into a list or something

    yields `(ext_path, ext_item, item, path)` where `item` is the placeholder
    found in `json_data` at `ext_item.path`.

    NOTE(review): laziness appears intentional -- externals are yielded in
    order of increasing path length, and the caller (`populate_externals`)
    mutates `json_data` between yields, so deeper paths are resolved only
    after shallower externals have been populated. Materializing the iterator
    would resolve all paths up-front and break this.
    """

    # shortest paths first, so parents are handled before their children
    sorted_externals: list[tuple[str, ExternalItem]] = sorted(
        externals.items(), key=lambda x: len(x[1].path)
    )

    for ext_path, ext_item in sorted_externals:
        # get the path to the item
        path: ObjectPath = tuple(ext_item.path)
        assert len(path) > 0
        assert all(isinstance(key, (str, int)) for key in path), (
            f"improper types in path {path=}"
        )
        # get the item by walking json_data along the path
        item = json_data
        for i, key in enumerate(path):
            try:
                # ignores in this block are because we cannot know the type is indexable in static analysis
                # but, we check the types in the line below
                external_unloaded: bool = _populate_externals_error_checking(key, item)
                if external_unloaded:
                    # not-yet-loaded external: the payload lives under "data"
                    item = item["data"]  # type: ignore
                item = item[key]  # type: ignore[index]

            except (KeyError, IndexError, TypeError) as e:
                # re-raise with full context about where traversal failed
                raise KeyError(
                    f"could not find '{key = }' at path '{ext_path = }', specifically at index '{i = }'",
                    f"'{type(item) =}', '{len(item) = }', '{item.keys() if isinstance(item, dict) else None = }'",  # type: ignore
                    f"From error: {e = }",
                    f"\n\n{item=}\n\n{ext_item=}",
                ) from e

        yield (ext_path, ext_item, item, path)

388 

389 

class LoadedZANJ:
    """for loading a zanj file"""

    def __init__(
        self,
        path: str | Path,
        zanj: _ZANJ_pre,
    ) -> None:
        """read meta, main json data, and all external items from the archive

        # Parameters:
        - `path : str | Path` -- path to the `.zanj` (zip) archive
        - `zanj : _ZANJ_pre` -- the ZANJ object, passed to external load functions
        """
        # path and zanj object
        self._path: str = str(path)
        self._zanj: _ZANJ_pre = zanj

        # use context managers so the archive and its member handles are
        # closed even if a json.load or an external load function raises
        # (previously the ZipFile was closed manually with no try/finally,
        # leaking the handle on error, and member files were never closed)
        with zipfile.ZipFile(file=self._path, mode="r") as _zipf:
            # load metadata and main json data
            with _zipf.open(ZANJ_META, "r") as meta_fp:
                self._meta: JSONdict = json.load(meta_fp)
            with _zipf.open(ZANJ_MAIN, "r") as main_fp:
                self._json_data: JSONitem = json.load(main_fp)

            # read externals -- each external file is decoded by the load
            # function registered for its item_type
            self._externals: dict[str, ExternalItem] = dict()
            for fname, ext_item in self._meta["externals_info"].items():  # type: ignore
                item_type: str = ext_item["item_type"]  # type: ignore
                with _zipf.open(fname, "r") as fp:
                    self._externals[fname] = ExternalItem(
                        item_type=item_type,  # type: ignore[arg-type]
                        data=GET_EXTERNAL_LOAD_FUNC(item_type)(self, fp),
                        path=ext_item["path"],  # type: ignore
                    )

    def populate_externals(self) -> None:
        """put all external items into the main json data"""

        # loop over once, populating the externals only
        # (the iterator must stay lazy: _each_item_in_externals resolves each
        # path against json_data as it is being mutated here)
        for ext_path, ext_item, item, path in _each_item_in_externals(
            self._externals, self._json_data
        ):
            # the placeholder must reference this external file
            assert _REF_KEY in item  # type: ignore
            assert item[_REF_KEY] == ext_path  # type: ignore
            # splice the loaded external data into place
            item["data"] = ext_item.data  # type: ignore