Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/fields/json.py: 47%

341 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-23 11:16 -0600

1import json 

2import warnings 

3 

4from plain import exceptions, preflight 

5from plain.models import expressions, lookups 

6from plain.models.constants import LOOKUP_SEP 

7from plain.models.db import NotSupportedError, connections, router 

8from plain.models.fields import TextField 

9from plain.models.lookups import ( 

10 FieldGetDbPrepValueMixin, 

11 PostgresOperatorLookup, 

12 Transform, 

13) 

14from plain.utils.deprecation import RemovedInDjango51Warning 

15 

16from . import Field 

17from .mixins import CheckFieldDefaultMixin 

18 

# Names exported by a star-import of this module.
__all__ = ["JSONField"]

20 

21 

class JSONField(CheckFieldDefaultMixin, Field):
    """Model field holding a JSON-serializable Python value.

    ``encoder`` is passed to ``json.dumps`` (and to the backend's
    ``adapt_json_value``); ``decoder`` is passed to ``json.loads``.
    """

    empty_strings_allowed = False
    description = "A JSON object"
    default_error_messages = {
        "invalid": "Value must be valid JSON.",
    }
    _default_hint = ("dict", "{}")

    def __init__(
        self,
        name=None,
        encoder=None,
        decoder=None,
        **kwargs,
    ):
        """Validate and store the optional encoder/decoder.

        Raises:
            ValueError: if ``encoder`` or ``decoder`` is truthy but not
                callable.
        """
        if encoder and not callable(encoder):
            raise ValueError("The encoder parameter must be a callable object.")
        if decoder and not callable(decoder):
            raise ValueError("The decoder parameter must be a callable object.")
        self.encoder = encoder
        self.decoder = decoder
        super().__init__(name, **kwargs)

    def check(self, **kwargs):
        """Return the base checks plus JSON-support checks per database.

        The database aliases are read from ``kwargs["databases"]``; a missing
        or empty value means no backend is inspected.
        """
        errors = super().check(**kwargs)
        databases = kwargs.get("databases") or []
        errors.extend(self._check_supported(databases))
        return errors

    def _check_supported(self, databases):
        """Return a ``fields.E180`` error for each backend lacking JSON support."""
        errors = []
        for db in databases:
            # Databases this model is never migrated to are irrelevant.
            if not router.allow_migrate_model(db, self.model):
                continue
            connection = connections[db]
            # The model is pinned to a different vendor than this connection.
            if (
                self.model._meta.required_db_vendor
                and self.model._meta.required_db_vendor != connection.vendor
            ):
                continue
            # No error when the model itself declares the feature as required.
            if not (
                "supports_json_field" in self.model._meta.required_db_features
                or connection.features.supports_json_field
            ):
                errors.append(
                    preflight.Error(
                        f"{connection.display_name} does not support JSONFields.",
                        obj=self.model,
                        id="fields.E180",
                    )
                )
        return errors

    def deconstruct(self):
        """Extend the base deconstruction with any explicit encoder/decoder."""
        name, path, args, kwargs = super().deconstruct()
        if self.encoder is not None:
            kwargs["encoder"] = self.encoder
        if self.decoder is not None:
            kwargs["decoder"] = self.decoder
        return name, path, args, kwargs

    def from_db_value(self, value, expression, connection):
        """Decode a database value; undecodable strings are returned as-is."""
        if value is None:
            return value
        # Some backends (SQLite at least) extract non-string values in their
        # SQL datatypes.
        if isinstance(expression, KeyTransform) and not isinstance(value, str):
            return value
        try:
            return json.loads(value, cls=self.decoder)
        except json.JSONDecodeError:
            return value

    def get_internal_type(self):
        """Return the internal type name, ``"JSONField"``."""
        return "JSONField"

    def get_db_prep_value(self, value, connection, prepared=False):
        """Adapt ``value`` for the backend via ``connection.ops.adapt_json_value``.

        Objects exposing ``as_sql`` are returned untouched, as are ``Value()``
        wrappers that neither hold a string nor have a JSONField output field.
        A ``Value()`` holding a string without a JSONField output field is
        decoded first, which emits a deprecation warning when decoding
        succeeds.
        """
        # RemovedInDjango51Warning: When the deprecation ends, replace with:
        # if (
        #     isinstance(value, expressions.Value)
        #     and isinstance(value.output_field, JSONField)
        # ):
        #     value = value.value
        # elif hasattr(value, "as_sql"): ...
        if isinstance(value, expressions.Value):
            if isinstance(value.value, str) and not isinstance(
                value.output_field, JSONField
            ):
                try:
                    value = json.loads(value.value, cls=self.decoder)
                except json.JSONDecodeError:
                    # Not encoded JSON: treat the raw string as the value.
                    value = value.value
                else:
                    warnings.warn(
                        "Providing an encoded JSON string via Value() is deprecated. "
                        f"Use Value({value!r}, output_field=JSONField()) instead.",
                        category=RemovedInDjango51Warning,
                    )
            elif isinstance(value.output_field, JSONField):
                value = value.value
            else:
                return value
        elif hasattr(value, "as_sql"):
            return value
        return connection.ops.adapt_json_value(value, self.encoder)

    def get_db_prep_save(self, value, connection):
        """Prepare ``value`` for saving; ``None`` is passed through unadapted."""
        if value is None:
            return value
        return self.get_db_prep_value(value, connection)

    def get_transform(self, name):
        """Return a registered transform, else a key transform for ``name``."""
        transform = super().get_transform(name)
        if transform:
            return transform
        return KeyTransformFactory(name)

    def validate(self, value, model_instance):
        """Run base validation and ensure ``value`` can be serialized as JSON.

        Raises:
            exceptions.ValidationError: with code ``"invalid"`` when
                ``json.dumps`` rejects the value.
        """
        super().validate(value, model_instance)
        try:
            json.dumps(value, cls=self.encoder)
        except (TypeError, ValueError) as exc:
            # TypeError: unsupported object type. ValueError: json.dumps also
            # rejects e.g. circular references this way; both mean the value
            # is not valid JSON.
            raise exceptions.ValidationError(
                self.error_messages["invalid"],
                code="invalid",
                params={"value": value},
            ) from exc

    def value_to_string(self, obj):
        """Return the field's value on ``obj`` unchanged (not stringified)."""
        return self.value_from_object(obj)

152 

153 

def compile_json_path(key_transforms, include_root=True):
    """Build a JSON path string from a sequence of keys.

    Keys accepted by ``int()`` become array subscripts (``[n]``); every other
    key becomes a ``.``-prefixed, JSON-quoted member name. The path starts
    with ``$`` unless ``include_root`` is false.
    """
    segments = []
    if include_root:
        segments.append("$")
    for key in key_transforms:
        try:
            segment = f"[{int(key)}]"
        except ValueError:  # non-integer
            segment = "." + json.dumps(key)
        segments.append(segment)
    return "".join(segments)

165 

166 

class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """``contains`` lookup: ``@>`` on PostgreSQL, ``JSON_CONTAINS`` otherwise."""

    lookup_name = "contains"
    postgres_operator = "@>"

    def as_sql(self, compiler, connection):
        """Return ``JSON_CONTAINS(lhs, rhs)`` with lhs parameters first.

        Raises:
            NotSupportedError: if the backend lacks
                ``supports_json_field_contains``.
        """
        if connection.features.supports_json_field_contains:
            lhs_sql, lhs_args = self.process_lhs(compiler, connection)
            rhs_sql, rhs_args = self.process_rhs(compiler, connection)
            return f"JSON_CONTAINS({lhs_sql}, {rhs_sql})", (*lhs_args, *rhs_args)
        raise NotSupportedError(
            "contains lookup is not supported on this database backend."
        )

180 

181 

class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """``contained_by`` lookup: ``<@`` on PostgreSQL, ``JSON_CONTAINS`` otherwise."""

    lookup_name = "contained_by"
    postgres_operator = "<@"

    def as_sql(self, compiler, connection):
        """Return ``JSON_CONTAINS(rhs, lhs)`` with rhs parameters first.

        Raises:
            NotSupportedError: if the backend lacks
                ``supports_json_field_contains``.
        """
        if connection.features.supports_json_field_contains:
            lhs_sql, lhs_args = self.process_lhs(compiler, connection)
            rhs_sql, rhs_args = self.process_rhs(compiler, connection)
            # Operands are swapped relative to ``contains``.
            return f"JSON_CONTAINS({rhs_sql}, {lhs_sql})", (*rhs_args, *lhs_args)
        raise NotSupportedError(
            "contained_by lookup is not supported on this database backend."
        )

195 

196 

class HasKeyLookup(PostgresOperatorLookup):
    """Shared implementation for the key-existence lookups.

    Subclasses set ``lookup_name`` and ``postgres_operator``; multi-key
    variants also set ``logical_operator`` to join the per-key conditions.
    """

    logical_operator = None

    def compile_json_path_final_key(self, key_transform):
        """Compile the last key as a member name, never as an array index."""
        # Compile the final key without interpreting ints as array elements.
        return "." + json.dumps(key_transform)

    def as_sql(self, compiler, connection, template=None):
        """Render ``template`` once per right-hand key.

        ``template`` receives the left-hand SQL via ``%`` formatting; each
        key contributes one JSON path parameter. Returns ``(sql, params)``.
        """
        # Process JSON path from the left-hand side.
        if isinstance(self.lhs, KeyTransform):
            lhs_sql, lhs_args, lhs_keys = self.lhs.preprocess_lhs(
                compiler, connection
            )
            base_path = compile_json_path(lhs_keys)
        else:
            lhs_sql, lhs_args = self.process_lhs(compiler, connection)
            base_path = "$"
        condition = template % lhs_sql
        # Process JSON path from the right-hand side.
        keys = self.rhs if isinstance(self.rhs, list | tuple) else [self.rhs]
        paths = []
        for key in keys:
            if isinstance(key, KeyTransform):
                *_, chain = key.preprocess_lhs(compiler, connection)
            else:
                chain = [key]
            *parents, final_key = chain
            paths.append(
                base_path
                + compile_json_path(parents, include_root=False)
                + self.compile_json_path_final_key(final_key)
            )
        # Add condition for each key.
        if self.logical_operator:
            joined = self.logical_operator.join([condition] * len(paths))
            condition = f"({joined})"
        return condition, (*lhs_args, *paths)

    def as_mysql(self, compiler, connection):
        """Compile using ``JSON_CONTAINS_PATH``."""
        mysql_template = "JSON_CONTAINS_PATH(%s, 'one', %%s)"
        return self.as_sql(compiler, connection, template=mysql_template)

    def as_postgresql(self, compiler, connection):
        """Compile with the PostgreSQL operator.

        When the rhs is a ``KeyTransform``, all but its last key are wrapped
        around ``self.lhs`` and the last key becomes ``self.rhs``.
        """
        if isinstance(self.rhs, KeyTransform):
            *_, chain = self.rhs.preprocess_lhs(compiler, connection)
            for parent_key in chain[:-1]:
                self.lhs = KeyTransform(parent_key, self.lhs)
            self.rhs = chain[-1]
        return super().as_postgresql(compiler, connection)

    def as_sqlite(self, compiler, connection):
        """Compile using a ``JSON_TYPE(...) IS NOT NULL`` test."""
        sqlite_template = "JSON_TYPE(%s, %%s) IS NOT NULL"
        return self.as_sql(compiler, connection, template=sqlite_template)

251 

252 

class HasKey(HasKeyLookup):
    """``has_key`` lookup using the PostgreSQL ``?`` operator."""

    lookup_name = "has_key"
    postgres_operator = "?"
    prepare_rhs = False

257 

258 

class HasKeys(HasKeyLookup):
    """``has_keys`` lookup: per-key conditions are joined with ``AND``."""

    lookup_name = "has_keys"
    postgres_operator = "?&"
    logical_operator = " AND "

    def get_prep_lookup(self):
        """Return every right-hand item converted with ``str``."""
        return list(map(str, self.rhs))

266 

267 

class HasAnyKeys(HasKeys):
    """``has_any_keys`` lookup: per-key conditions are joined with ``OR``."""

    lookup_name = "has_any_keys"
    postgres_operator = "?|"
    logical_operator = " OR "

272 

273 

class HasKeyOrArrayIndex(HasKey):
    """``HasKey`` variant whose final key may compile to an array index."""

    def compile_json_path_final_key(self, key_transform):
        """Compile the last key with ``compile_json_path`` (no root)."""
        final_segment = compile_json_path([key_transform], include_root=False)
        return final_segment

277 

278 

class CaseInsensitiveMixin:
    """
    Mixin to allow case-insensitive comparison of JSON values on MySQL.
    MySQL handles strings used in JSON context using the utf8mb4_bin collation.
    Because utf8mb4_bin is a binary collation, comparison of JSON values is
    case-sensitive.
    """

    def process_lhs(self, compiler, connection):
        """Wrap the left-hand SQL in ``LOWER()`` when the vendor is MySQL."""
        sql, params = super().process_lhs(compiler, connection)
        if connection.vendor != "mysql":
            return sql, params
        return f"LOWER({sql})", params

    def process_rhs(self, compiler, connection):
        """Wrap the right-hand SQL in ``LOWER()`` when the vendor is MySQL."""
        sql, params = super().process_rhs(compiler, connection)
        if connection.vendor != "mysql":
            return sql, params
        return f"LOWER({sql})", params

298 

299 

class JSONExact(lookups.Exact):
    """Exact lookup for JSON values; ``None`` is compared as JSON ``null``."""

    can_use_none_as_rhs = True

    def process_rhs(self, compiler, connection):
        """Return the right-hand SQL and params, adjusted for JSON comparison."""
        sql, params = super().process_rhs(compiler, connection)
        # Treat None lookup values as null.
        if sql == "%s" and params == [None]:
            params = ["null"]
        if connection.vendor == "mysql":
            # One JSON_EXTRACT wrapper per placeholder in the rhs SQL.
            wrappers = ("JSON_EXTRACT(%s, '$')",) * len(params)
            sql = sql % wrappers
        return sql, params

312 

313 

class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
    """``lookups.IContains`` combined with ``CaseInsensitiveMixin``."""

316 

317 

# Lookups available directly on JSONField.
JSONField.register_lookup(DataContains)
JSONField.register_lookup(ContainedBy)
JSONField.register_lookup(HasKey)
JSONField.register_lookup(HasKeys)
JSONField.register_lookup(HasAnyKeys)
JSONField.register_lookup(JSONExact)
JSONField.register_lookup(JSONIContains)

325 

326 

class KeyTransform(Transform):
    """Transform that extracts one key (or index) from a JSON expression."""

    postgres_operator = "->"
    postgres_nested_operator = "#>"

    def __init__(self, key_name, *args, **kwargs):
        """Store ``key_name`` as a string; other arguments go to ``Transform``."""
        super().__init__(*args, **kwargs)
        self.key_name = str(key_name)

    def preprocess_lhs(self, compiler, connection):
        """Collapse nested key transforms into one key list.

        Returns ``(sql, params, keys)`` where ``sql``/``params`` come from
        compiling the innermost non-``KeyTransform`` expression and ``keys``
        is ordered from outermost container to this transform's key.
        """
        keys = [self.key_name]
        node = self.lhs
        while isinstance(node, KeyTransform):
            keys.append(node.key_name)
            node = node.lhs
        # Collected innermost-last; callers expect outermost-first.
        keys.reverse()
        sql, params = compiler.compile(node)
        return sql, params, keys

    def as_mysql(self, compiler, connection):
        """Compile to ``JSON_EXTRACT`` with the path as a parameter."""
        base_sql, base_params, keys = self.preprocess_lhs(compiler, connection)
        path = compile_json_path(keys)
        return f"JSON_EXTRACT({base_sql}, %s)", (*base_params, path)

    def as_postgresql(self, compiler, connection):
        """Compile to the nested operator for several keys, else the plain one."""
        base_sql, base_params, keys = self.preprocess_lhs(compiler, connection)
        if len(keys) > 1:
            nested_sql = f"({base_sql} {self.postgres_nested_operator} %s)"
            return nested_sql, (*base_params, keys)
        # A single key that parses as an integer is passed as an int.
        try:
            key = int(self.key_name)
        except ValueError:
            key = self.key_name
        return f"({base_sql} {self.postgres_operator} %s)", (*base_params, key)

    def as_sqlite(self, compiler, connection):
        """Compile to a CASE choosing ``JSON_TYPE`` or ``JSON_EXTRACT``."""
        base_sql, base_params, keys = self.preprocess_lhs(compiler, connection)
        path = compile_json_path(keys)
        type_names = ",".join(map(repr, connection.ops.jsonfield_datatype_values))
        sql = (
            f"(CASE WHEN JSON_TYPE({base_sql}, %s) IN ({type_names}) "
            f"THEN JSON_TYPE({base_sql}, %s) ELSE JSON_EXTRACT({base_sql}, %s) END)"
        )
        # The lhs params and path appear once per function call in the SQL.
        return sql, (*base_params, path) * 3

370 

371 

class KeyTextTransform(KeyTransform):
    """``KeyTransform`` variant with text operators and a ``TextField`` output."""

    postgres_operator = "->>"
    postgres_nested_operator = "#>>"
    output_field = TextField()

    def as_mysql(self, compiler, connection):
        """Compile with ``->>``, or ``JSON_UNQUOTE`` on MariaDB."""
        if not connection.mysql_is_mariadb:
            base_sql, base_params, keys = self.preprocess_lhs(compiler, connection)
            path = compile_json_path(keys)
            return f"({base_sql} ->> %s)", (*base_params, path)
        # MariaDB doesn't support -> and ->> operators (see MDEV-13594).
        extract_sql, extract_params = super().as_mysql(compiler, connection)
        return f"JSON_UNQUOTE({extract_sql})", extract_params

    @classmethod
    def from_lookup(cls, lookup):
        """Build nested transforms from a ``LOOKUP_SEP``-separated string.

        The first segment is the innermost expression; each following
        segment wraps the result in another instance of ``cls``.

        Raises:
            ValueError: if ``lookup`` has no segment after the first.
        """
        expression, *key_names = lookup.split(LOOKUP_SEP)
        if not key_names:
            raise ValueError("Lookup must contain key or index transforms.")
        for key_name in key_names:
            expression = cls(key_name, expression)
        return expression

395 

396 

# Short alias for KeyTextTransform.from_lookup.
KT = KeyTextTransform.from_lookup

398 

399 

class KeyTransformTextLookupMixin:
    """
    Mixin for combining with a lookup expecting a text lhs from a JSONField
    key lookup. On PostgreSQL, make use of the ->> operator instead of casting
    key values to text and performing the lookup on the resulting
    representation.
    """

    def __init__(self, key_transform, *args, **kwargs):
        """Replace ``key_transform`` with an equivalent ``KeyTextTransform``.

        Raises:
            TypeError: if ``key_transform`` is not a ``KeyTransform``.
        """
        if not isinstance(key_transform, KeyTransform):
            raise TypeError(
                "Transform should be an instance of KeyTransform in order to "
                "use this lookup."
            )
        sources = key_transform.source_expressions
        extra = key_transform.extra
        text_transform = KeyTextTransform(key_transform.key_name, *sources, **extra)
        super().__init__(text_transform, *args, **kwargs)

420 

421 

class KeyTransformIsNull(lookups.IsNull):
    """``isnull`` lookup on a JSON key, with a SQLite-specific compilation."""

    # key__isnull=False is the same as has_key='key'
    def as_sqlite(self, compiler, connection):
        """Compile through ``HasKeyOrArrayIndex`` with a ``JSON_TYPE`` test."""
        if self.rhs:
            template = "JSON_TYPE(%s, %%s) IS NULL"
        else:
            template = "JSON_TYPE(%s, %%s) IS NOT NULL"
        key_lookup = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
        return key_lookup.as_sql(compiler, connection, template=template)

433 

434 

class KeyTransformIn(lookups.In):
    """``in`` lookup on a JSON key with backend-specific value wrapping."""

    def resolve_expression_parameter(self, compiler, connection, sql, param):
        """Return ``(sql, params)`` for one member of the ``in`` list.

        Plain values on backends without a native JSON field are wrapped in
        ``JSON_EXTRACT`` (on SQLite only when the value is not one of the
        backend's JSON datatype names); MariaDB additionally unquotes.
        """
        sql, params = super().resolve_expression_parameter(
            compiler, connection, sql, param
        )
        vendor = connection.vendor
        wrap_in_extract = (
            not hasattr(param, "as_sql")
            and not connection.features.has_native_json_field
            and (
                vendor == "mysql"
                or (
                    vendor == "sqlite"
                    and params[0] not in connection.ops.jsonfield_datatype_values
                )
            )
        )
        if wrap_in_extract:
            sql = "JSON_EXTRACT(%s, '$')"
        if vendor == "mysql" and connection.mysql_is_mariadb:
            sql = f"JSON_UNQUOTE({sql})"
        return sql, params

455 

456 

class KeyTransformExact(JSONExact):
    """Exact lookup on a JSON key."""

    def process_rhs(self, compiler, connection):
        """Return the right-hand SQL and params for an exact key comparison."""
        if isinstance(self.rhs, KeyTransform):
            # Start the MRO search after lookups.Exact, so the rhs handling of
            # JSONExact and lookups.Exact is bypassed.
            return super(lookups.Exact, self).process_rhs(compiler, connection)
        sql, params = super().process_rhs(compiler, connection)
        if connection.vendor == "sqlite":
            type_names = connection.ops.jsonfield_datatype_values
            # JSON datatype names stay bare; other values are wrapped.
            placeholders = tuple(
                "%s" if value in type_names else "JSON_EXTRACT(%s, '$')"
                for value in params
            )
            sql = sql % placeholders
        return sql, params

471 

472 

class KeyTransformIExact(
    CaseInsensitiveMixin,
    KeyTransformTextLookupMixin,
    lookups.IExact,
):
    """``lookups.IExact`` applied to the text value of a JSON key."""

477 

478 

class KeyTransformIContains(
    CaseInsensitiveMixin,
    KeyTransformTextLookupMixin,
    lookups.IContains,
):
    """``lookups.IContains`` applied to the text value of a JSON key."""

483 

484 

class KeyTransformStartsWith(
    KeyTransformTextLookupMixin,
    lookups.StartsWith,
):
    """``lookups.StartsWith`` applied to the text value of a JSON key."""

487 

488 

class KeyTransformIStartsWith(
    CaseInsensitiveMixin,
    KeyTransformTextLookupMixin,
    lookups.IStartsWith,
):
    """``lookups.IStartsWith`` applied to the text value of a JSON key."""

493 

494 

class KeyTransformEndsWith(
    KeyTransformTextLookupMixin,
    lookups.EndsWith,
):
    """``lookups.EndsWith`` applied to the text value of a JSON key."""

497 

498 

class KeyTransformIEndsWith(
    CaseInsensitiveMixin,
    KeyTransformTextLookupMixin,
    lookups.IEndsWith,
):
    """``lookups.IEndsWith`` applied to the text value of a JSON key."""

503 

504 

class KeyTransformRegex(
    KeyTransformTextLookupMixin,
    lookups.Regex,
):
    """``lookups.Regex`` applied to the text value of a JSON key."""

507 

508 

class KeyTransformIRegex(
    CaseInsensitiveMixin,
    KeyTransformTextLookupMixin,
    lookups.IRegex,
):
    """``lookups.IRegex`` applied to the text value of a JSON key."""

513 

514 

class KeyTransformNumericLookupMixin:
    """Mixin that JSON-decodes rhs params on backends without native JSON."""

    def process_rhs(self, compiler, connection):
        """Return the rhs SQL and params, decoding params when needed."""
        sql, params = super().process_rhs(compiler, connection)
        if connection.features.has_native_json_field:
            return sql, params
        return sql, list(map(json.loads, params))

521 

522 

class KeyTransformLt(
    KeyTransformNumericLookupMixin,
    lookups.LessThan,
):
    """``lookups.LessThan`` on a JSON key, using the numeric rhs mixin."""

525 

526 

class KeyTransformLte(
    KeyTransformNumericLookupMixin,
    lookups.LessThanOrEqual,
):
    """``lookups.LessThanOrEqual`` on a JSON key, using the numeric rhs mixin."""

529 

530 

class KeyTransformGt(
    KeyTransformNumericLookupMixin,
    lookups.GreaterThan,
):
    """``lookups.GreaterThan`` on a JSON key, using the numeric rhs mixin."""

533 

534 

class KeyTransformGte(
    KeyTransformNumericLookupMixin,
    lookups.GreaterThanOrEqual,
):
    """``lookups.GreaterThanOrEqual`` on a JSON key, using the numeric rhs mixin."""

537 

538 

# Lookups available after a key transform.
KeyTransform.register_lookup(KeyTransformIn)
KeyTransform.register_lookup(KeyTransformExact)
KeyTransform.register_lookup(KeyTransformIExact)
KeyTransform.register_lookup(KeyTransformIsNull)
KeyTransform.register_lookup(KeyTransformIContains)
KeyTransform.register_lookup(KeyTransformStartsWith)
KeyTransform.register_lookup(KeyTransformIStartsWith)
KeyTransform.register_lookup(KeyTransformEndsWith)
KeyTransform.register_lookup(KeyTransformIEndsWith)
KeyTransform.register_lookup(KeyTransformRegex)
KeyTransform.register_lookup(KeyTransformIRegex)

# Comparison lookups built on KeyTransformNumericLookupMixin.
KeyTransform.register_lookup(KeyTransformLt)
KeyTransform.register_lookup(KeyTransformLte)
KeyTransform.register_lookup(KeyTransformGt)
KeyTransform.register_lookup(KeyTransformGte)

555 

556 

class KeyTransformFactory:
    """Callable that builds a ``KeyTransform`` for a fixed key name."""

    def __init__(self, key_name):
        """Remember the key name for later transform construction."""
        self.key_name = key_name

    def __call__(self, *args, **kwargs):
        """Return ``KeyTransform(key_name, *args, **kwargs)``."""
        key_transform = KeyTransform(self.key_name, *args, **kwargs)
        return key_transform