Coverage for src/paperap/models/abstract/queryset.py: 80%

247 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-11 21:37 -0400

1""" 

2---------------------------------------------------------------------------- 

3 

4 METADATA: 

5 

6 File: queryset.py 

7 Project: paperap 

8 Created: 2025-03-04 

9 Version: 0.0.5 

10 Author: Jess Mann 

11 Email: jess@jmann.me 

12 Copyright (c) 2025 Jess Mann 

13 

14---------------------------------------------------------------------------- 

15 

16 LAST MODIFIED: 

17 

18 2025-03-04 By Jess Mann 

19 

20""" 

21 

22from __future__ import annotations 

23 

24import copy 

25import logging 

26from datetime import datetime 

27from string import Template 

28from typing import TYPE_CHECKING, Any, Generic, Iterable, Iterator, Optional, Self, Union, override 

29 

30from typing_extensions import TypeVar 

31from yarl import URL 

32 

33from paperap.exceptions import FilterDisabledError, MultipleObjectsFoundError, ObjectNotFoundError 

34 

35if TYPE_CHECKING: 

36 from paperap.models.abstract.model import BaseModel, StandardModel 

37 from paperap.resources.base import BaseResource, StandardResource 

38 

# Type variable for the model class a BaseQuerySet yields. It is bound to (and
# defaults to) BaseModel, which is only imported under TYPE_CHECKING, hence the
# string forward references.
_BaseModel = TypeVar("_BaseModel", bound="BaseModel", default="BaseModel", covariant=True)

# Same idea for StandardQuerySet, bound to (and defaulting to) StandardModel.
_StandardModel = TypeVar("_StandardModel", bound="StandardModel", default="StandardModel", covariant=True)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

43 

44 

class BaseQuerySet(Iterable[_BaseModel], Generic[_BaseModel]):
    """
    A lazy-loaded, chainable query interface for Paperless NGX resources.

    BaseQuerySet provides pagination, filtering, and caching functionality similar to Django's QuerySet.
    It's designed to be lazy - only fetching data when it's actually needed (iteration, indexing,
    ``count()``, ...). Chaining methods such as ``filter()``, ``exclude()`` and ``order_by()`` return a
    new queryset built by ``_chain()``, which starts with an empty result cache.

    Args:
        resource: The BaseResource instance.
        filters: Initial filter parameters.
        _cache: Optional internal result cache.
        _fetch_all: Whether all results have been fetched.
        _next_url: URL for the next page of results.
        _last_response: Optional last response from the API.
        _iter: Optional iterator for the results.
        _urls_fetched: Optional list of "next" URLs that were already followed.

    Returns:
        A new instance of BaseQuerySet.

    Examples:
        # Create a QuerySet for documents
        >>> docs = client.documents()
        >>> for doc in docs:
        ...    print(doc.id)
        1
        2
        3

    """

    # Class-level declarations / defaults. __init__ rebinds every one of these on
    # the instance, so the mutable defaults below are never shared in practice.
    resource: "BaseResource[_BaseModel]"
    filters: dict[str, Any]
    _last_response: dict[str, Any] | None = None
    _result_cache: list[_BaseModel] = []
    _fetch_all: bool = False
    _next_url: str | None = None
    _urls_fetched: list[str] = []
    _iter: Iterator[_BaseModel] | None

    def __init__(
        self,
        resource: "BaseResource[_BaseModel]",
        filters: Optional[dict[str, Any]] = None,
        _cache: Optional[list[_BaseModel]] = None,
        _fetch_all: bool = False,
        _next_url: str | None = None,
        _last_response: Optional[dict[str, Any]] = None,
        _iter: Optional[Iterator[_BaseModel]] = None,
        _urls_fetched: Optional[list[str]] = None,
    ) -> None:
        self.resource = resource
        self.filters = filters or {}
        self._result_cache = _cache or []
        self._fetch_all = _fetch_all
        self._next_url = _next_url
        self._urls_fetched = _urls_fetched or []
        self._last_response = _last_response
        self._iter = _iter

        super().__init__()

    @property
    def _model(self) -> type[_BaseModel]:
        """
        Return the model class associated with the resource.

        Returns:
            The model class

        Examples:
            # Create a model instance
            >>> model = queryset._model(**params)

        """
        return self.resource.model_class

    @property
    def _meta(self) -> "BaseModel.Meta":
        """
        Return the model's metadata.

        Returns:
            The model's metadata

        Examples:
            # Get the model's metadata
            >>> queryset._meta.read_only_fields
            {'id', 'added', 'modified'}

        """
        return self._model._meta  # pyright: ignore[reportPrivateUsage] # pylint: disable=protected-access

    def _reset(self) -> None:
        """
        Reset the QuerySet to its initial state.

        This clears the result cache and resets the fetch/pagination state. Filters are left untouched.
        """
        self._result_cache = []
        self._fetch_all = False
        self._next_url = None
        self._urls_fetched = []
        self._last_response = None
        self._iter = None

    def _update_filters(self, values: dict[str, Any]) -> None:
        """
        Update the current filters with new values.

        This updates the current queryset instance. It does not return a new instance. For that reason,
        do not call this directly. Call filter() or exclude() instead.

        Args:
            values: New filter values to add

        Raises:
            FilterDisabledError: If a filter is not allowed by the resource

        Examples:
            # Update filters with new values
            queryset._update_filters({"correspondent": 1})

            # Update filters with multiple values
            queryset._update_filters({"correspondent": 1, "document_type": 2})

        """
        # Validate every key first so a rejected filter leaves the queryset untouched.
        for key in values:
            if not self._meta.filter_allowed(key):
                raise FilterDisabledError(
                    f"Filtering by {key} for {self.resource.name} does not appear to be supported by the API."
                )

        if values:
            # Reset the cache if filters change
            self._reset()
            self.filters.update(values)

    def filter(self, **kwargs: Any) -> Self:
        """
        Return a new QuerySet with the given filters applied.

        List, set and tuple values are sent as a comma-separated string; booleans are sent as
        lowercase "true"/"false"; every other value is passed through unchanged.

        Args:
            **kwargs: Filters to apply, where keys are field names and values are desired values.
                Supports Django-style lookups like field__contains, field__in, etc.

        Returns:
            A new QuerySet with the additional filters applied

        Raises:
            FilterDisabledError: If a filter is not allowed by the resource

        Examples:
            # Get documents with specific correspondent
            docs = client.documents.filter(correspondent=1)

            # Get documents with specific correspondent and document type
            docs = client.documents.filter(correspondent=1, document_type=2)

            # Get documents with title containing "invoice"
            docs = client.documents.filter(title__contains="invoice")

            # Get documents with IDs in a list
            docs = client.documents.filter(id__in=[1, 2, 3])

        """
        processed_filters: dict[str, Any] = {}

        for key, value in kwargs.items():
            if isinstance(value, (list, set, tuple)):
                # Convert collections to a comma-separated string for the API
                processed_filters[key] = ",".join(str(item) for item in value)
            elif isinstance(value, bool):
                processed_filters[key] = str(value).lower()
            else:
                processed_filters[key] = value

        return self._chain(filters={**self.filters, **processed_filters})

    def exclude(self, **kwargs: Any) -> Self:
        """
        Return a new QuerySet excluding objects with the given filters.

        Each key is rewritten to its negated form: ``field`` becomes ``field__not``,
        ``field__lookup`` becomes ``field__not_lookup``, and an existing ``not_`` prefix is removed.

        Args:
            **kwargs: Filters to exclude, where keys are field names and values are excluded values

        Returns:
            A new QuerySet excluding objects that match the filters

        Raises:
            FilterDisabledError: If a resulting filter is not allowed by the resource

        Examples:
            # Get documents with any correspondent except ID 1
            docs = client.documents.exclude(correspondent=1)

        """
        exclude_filters: dict[str, Any] = {}
        for key, value in kwargs.items():
            if "__" not in key:
                exclude_filters[f"{key}__not"] = value
                continue

            field, lookup = key.split("__", 1)
            if lookup.startswith("not_"):
                # Excluding a negated lookup is the plain lookup
                exclude_filters[f"{field}__{lookup[4:]}"] = value
            else:
                exclude_filters[f"{field}__not_{lookup}"] = value

        return self._chain(filters={**self.filters, **exclude_filters})

    def get(self, pk: Any) -> _BaseModel:
        """
        Retrieve a single object from the API.

        Raises NotImplementedError. Subclasses may implement this.

        Args:
            pk: The primary key (e.g. the id) of the object to retrieve

        Returns:
            A single object matching the query

        Raises:
            NotImplementedError: Always, in this base class

        Examples:
            # Get document with ID 123
            doc = client.documents.get(123)

        """
        raise NotImplementedError("Getting a single resource is not defined by BaseModels without an id.")

    def _ensure_last_response(self) -> dict[str, Any]:
        """
        Return the last API response, requesting the first page if there is none yet.

        Returns:
            The last response

        Raises:
            NotImplementedError: If no response is available even after requesting a page

        """
        if not self._last_response:
            # _request_iter is a generator: the request is only sent once it is advanced.
            next(self._request_iter(params=self.filters), None)

        if not self._last_response:
            # I don't think this should ever occur, but just in case.
            raise NotImplementedError("Requested iter, but no last response")

        return self._last_response

    def count(self) -> int:
        """
        Return the total number of objects in the queryset.

        Uses the "count" field of the last response, requesting one page first if no
        response has been received yet.

        Returns:
            The total count of objects matching the filters

        Raises:
            NotImplementedError: If the response does not have a count attribute

        """
        had_response = bool(self._last_response)
        response = self._ensure_last_response()

        if (count := response.get("count")) is not None:
            return count

        if had_response:
            raise NotImplementedError("Response does not have a count attribute.")

        # I don't think this should ever occur, but just in case.
        raise NotImplementedError(
            f"Unexpected Error: Could not determine count of objects. Last response: {response}"
        )

    def count_this_page(self) -> int:
        """
        Return the number of objects on the current page.

        Counts the "results" of the last response, requesting one page first if no
        response has been received yet.

        Returns:
            The count of objects on the current page

        Raises:
            NotImplementedError: If no response could be obtained

        """
        return len(self._ensure_last_response().get("results", []))

    def all(self) -> Self:
        """
        Return a new QuerySet that copies the current one.

        Returns:
            A copy of the current BaseQuerySet (same filters, empty cache)

        """
        return self._chain()

    def order_by(self, *fields: str) -> Self:
        """
        Return a new QuerySet ordered by the specified fields.

        New fields are appended to any ordering already present in the filters.

        Args:
            *fields: Field names to order by. Prefix with '-' for descending order.

        Returns:
            A new QuerySet with the ordering applied, or this same QuerySet if no fields are given

        Examples:
            # Order documents by title ascending
            docs = client.documents.order_by('title')

            # Order documents by added date descending
            docs = client.documents.order_by('-added')

        """
        if not fields:
            return self

        # Combine with existing ordering if any
        ordering = self.filters.get("ordering", [])
        if isinstance(ordering, str):
            # An empty string would otherwise produce a leading comma
            ordering = [ordering] if ordering else []
        elif not isinstance(ordering, list):
            ordering = list(ordering)

        # Join with commas for API
        ordering_param = ",".join(ordering + list(fields))

        return self._chain(filters={**self.filters, "ordering": ordering_param})

    def first(self) -> Optional[_BaseModel]:
        """
        Return the first object in the QuerySet, or None if empty.

        Returns:
            The first object or None if no objects match

        """
        if self._result_cache:
            return self._result_cache[0]

        # Not cached: ask for a single result and stop after the first object,
        # rather than following every "next" link the API returns.
        clone = self._chain(filters={**self.filters, "limit": 1})
        return next(iter(clone), None)

    def last(self) -> Optional[_BaseModel]:
        """
        Return the last object in the QuerySet, or None if empty.

        Note: This requires fetching all results to determine the last one.

        Returns:
            The last object or None if no objects match

        """
        # No-op when everything has already been fetched
        self._fetch_all_results()
        return self._result_cache[-1] if self._result_cache else None

    def exists(self) -> bool:
        """
        Return True if the QuerySet contains any results.

        Returns:
            True if there are any objects matching the filters

        """
        # Check the cache before potentially making a new request
        if self._fetch_all or self._result_cache:
            return len(self._result_cache) > 0

        # Check if there's at least one result
        return self.first() is not None

    def none(self) -> Self:
        """
        Return an empty QuerySet.

        Implemented by chaining a ``limit`` of 0 onto the current filters.

        Returns:
            An empty QuerySet

        """
        return self._chain(filters={"limit": 0})

    def filter_field_by_str(self, field: str, value: str, *, exact: bool = True, case_insensitive: bool = True) -> Self:
        """
        Filter a queryset based on a given field.

        This allows subclasses to easily implement custom filter methods.

        Args:
            field: The field name to filter by.
            value: The value to filter against.
            exact: Whether to filter by an exact match.
            case_insensitive: Whether the filter should be case-insensitive.

        Returns:
            A new QuerySet instance with the filter applied.

        """
        if exact:
            lookup = f"{field}__iexact" if case_insensitive else field
        else:
            lookup = f"{field}__icontains" if case_insensitive else f"{field}__contains"

        return self.filter(**{lookup: value})

    def _fetch_all_results(self) -> None:
        """
        Fetch all results from the API and populate the cache.

        Does nothing if everything was already fetched. Otherwise the cache and pagination
        state are cleared and every page is requested again, starting from the first one.

        Returns:
            None

        """
        if self._fetch_all:
            return

        # Start from a clean slate. Stale paging state from a partial iteration would
        # otherwise make _get_next() treat already-seen URLs as a loop and stop early.
        self._result_cache = []
        self._urls_fetched = []
        self._next_url = None
        self._last_response = None

        # Initial page
        self._result_cache.extend(self._request_iter(params=self.filters))

        # _get_next() reads the "next" link of the page just consumed, and returns None
        # when there is none or when it was already followed.
        while self._get_next():
            self._result_cache.extend(self._request_iter(url=self._next_url))

        self._fetch_all = True

    def _request_iter(
        self, url: str | URL | Template | None = None, params: Optional[dict[str, Any]] = None
    ) -> Iterator[_BaseModel]:
        """
        Get an iterator of resources.

        This is a generator: the request is not sent, and ``_last_response`` is not updated,
        until the returned iterator is advanced for the first time.

        Args:
            url: The URL to request, if different from the resource's default.
            params: Query parameters.

        Returns:
            An iterator over the resources.

        Raises:
            NotImplementedError: If the request cannot be completed.

        Examples:
            # Iterate over documents
            for doc in queryset._request_iter():
                print(doc)

        """
        if not (response := self.resource.request_raw(url=url, params=params)):
            logger.debug("No response from request.")
            return

        self._last_response = response

        yield from self.resource.handle_response(**response)

    def _get_next(self, response: dict[str, Any] | None = None) -> str | None:
        """
        Get the next url, and adjust our references accordingly.

        Args:
            response: The response to read the "next" link from. Defaults to the last response.

        Returns:
            The next URL, or None if there is no next page or it was already fetched.

        """
        # Allow passing a different response
        if response is None:
            response = self._last_response

        # Last response is not set
        if not response or not (next_url := response.get("next")):
            self._next_url = None
            return None

        # For safety, check both instance attributes, even though the first check isn't strictly necessary
        # this hopefully future proofs any changes to the implementation
        if next_url == self._next_url or next_url in self._urls_fetched:
            logger.debug(
                "Next URL was previously fetched. Stopping iteration. URL: %s, Already Fetched: %s",
                next_url,
                self._urls_fetched,
            )
            self._next_url = None
            return None

        # Cache it
        self._next_url = next_url
        self._urls_fetched.append(next_url)
        return self._next_url

    def _chain(self, **kwargs: Any) -> Self:
        """
        Return a copy of the current BaseQuerySet with updated attributes.

        The clone is constructed with only the resource, receives a deep copy of the current
        filters, and does not inherit the cache or pagination state.

        Args:
            **kwargs: Attributes to update in the new BaseQuerySet. A non-empty ``filters``
                value is merged (and validated) through ``_update_filters``; everything else
                is set directly on the clone.

        Returns:
            A new QuerySet with the updated attributes

        Raises:
            FilterDisabledError: If a filter is not allowed by the resource

        """
        # Create a new BaseQuerySet with copied attributes
        clone = self.__class__(self.resource)

        # Copy attributes from self
        clone.filters = copy.deepcopy(self.filters)
        # Do not copy the cache, fetch_all, etc, since filters may change it

        # Update with provided kwargs
        for key, value in kwargs.items():
            if key == "filters" and value:
                clone._update_filters(value)  # pylint: disable=protected-access
            else:
                setattr(clone, key, value)

        return clone

    @override
    def __iter__(self) -> Iterator[_BaseModel]:
        """
        Iterate over the objects in the QuerySet.

        Objects are cached as they are yielded. Once every page has been consumed, later
        iterations are served from the cache without new requests.

        Returns:
            An iterator over the objects

        """
        # If we have a fully populated cache, use it - and stop there, otherwise
        # the pages would be requested, cached and yielded a second time.
        if self._fetch_all:
            yield from self._result_cache
            return

        if not self._iter:
            # Start a new iteration
            self._iter = self._request_iter(params=self.filters)

        while True:
            # Yield objects from the current page
            for obj in self._iter:
                self._result_cache.append(obj)
                yield obj

            # If there are more pages, keep going
            if not self._get_next():
                break
            self._iter = self._request_iter(url=self._next_url)

        # We've fetched everything
        self._fetch_all = True
        self._iter = None

    def __len__(self) -> int:
        """
        Return the number of objects in the QuerySet.

        Returns:
            The count of objects

        """
        return self.count()

    def __bool__(self) -> bool:
        """
        Return True if the QuerySet has any results.

        Returns:
            True if there are any objects matching the filters

        """
        return self.exists()

    def __getitem__(self, key: Union[int, slice]) -> Union[_BaseModel, list[_BaseModel]]:
        """
        Retrieve an item or slice of items from the QuerySet.

        Non-negative indexes and step-less slices are requested through ``limit``/``offset``
        filters on a clone. Negative bounds and slices with a step other than 1 fetch the full
        result set and slice it locally.

        Args:
            key: An integer index or slice

        Returns:
            A single object or list of objects

        Raises:
            IndexError: If the index is out of range

        """
        from itertools import islice  # local import: only needed here

        if isinstance(key, slice):
            start = key.start if key.start is not None else 0
            stop = key.stop

            # Negative bounds require knowing the full size, and a step cannot be
            # expressed with limit/offset.
            if start < 0 or (stop is not None and stop < 0) or key.step not in (None, 1):
                self._fetch_all_results()
                return self._result_cache[key]

            if stop is not None:
                wanted = stop - start
                if wanted <= 0:
                    return []

                window: dict[str, Any] = {"limit": wanted}
                if start > 0:
                    window["offset"] = start
                clone = self._chain(filters={**self.filters, **window})
                # Never hand back more than was asked for, even if the API keeps paginating
                return list(islice(clone, wanted))

            if start > 0:
                # Just offset: we need all results after the offset
                clone = self._chain(filters={**self.filters, "offset": start})
                clone._fetch_all_results()  # pylint: disable=protected-access
                return clone._result_cache  # pylint: disable=protected-access

            # Default to fetching all and slicing
            self._fetch_all_results()
            return self._result_cache[key]

        # Handle integer indexing
        if key < 0:
            # Negative indexing requires the full result set
            self._fetch_all_results()
            return self._result_cache[key]

        if key < len(self._result_cache):
            # Already have this item cached
            return self._result_cache[key]

        # Fetch specific item by position, stopping after the first object
        clone = self._chain(filters={**self.filters, "limit": 1, "offset": key})
        item = next(iter(clone), None)
        if item is None:
            raise IndexError(f"BaseQuerySet index {key} out of range")
        return item

    def __contains__(self, item: Any) -> bool:
        """
        Return True if the QuerySet contains the given object.

        Args:
            item: The object to check for

        Returns:
            True if the object is in the QuerySet. Always False if the item is not
            an instance of this QuerySet's model class.

        """
        if not isinstance(item, self._model):
            return False

        return any(obj == item for obj in self)

735 

736 

class StandardQuerySet(BaseQuerySet[_StandardModel], Generic[_StandardModel]):
    """
    A queryset for StandardModel instances (i.e. BaseModels with standard fields, like id).

    Adds id-based helpers on top of BaseQuerySet: ``get()`` by primary key, an ``id()`` filter,
    and a membership test that matches on id only.

    Args:
        resource: The BaseResource instance.
        filters: Initial filter parameters.

    Returns:
        A new instance of StandardQuerySet.

    Examples:
        # Create a StandardQuerySet for documents
        docs = StandardQuerySet(resource=client.documents)

    """

    @override
    def get(self, pk: int) -> _StandardModel:
        """
        Retrieve a single object from the API.

        Objects already present in the result cache are returned without a request;
        otherwise the lookup is delegated to ``self.resource.get``.

        Args:
            pk: The ID of the object to retrieve

        Returns:
            A single object matching the query

        Raises:
            ObjectNotFoundError: If no object or multiple objects are found
                (raised by the resource lookup, not by this method itself - TODO confirm)

        Examples:
            # Get document with ID 123
            doc = client.documents.get(123)

        """
        # Attempt to find it in the result cache
        for obj in self._result_cache:
            if obj.id == pk:
                return obj

        # Direct lookup by ID - use the resource's get method
        return self.resource.get(pk)

    def id(self, value: int | list[int] | set[int] | tuple[int, ...]) -> Self:
        """
        Filter models by ID.

        Args:
            value: The ID, or a list/set/tuple of IDs, to filter by

        Returns:
            Filtered QuerySet

        """
        # Same collection types that filter() joins into a comma-separated value
        if isinstance(value, (list, set, tuple)):
            return self.filter(id__in=value)
        return self.filter(id=value)

    @override
    def __contains__(self, item: "StandardModel | int") -> bool:
        """
        Return True if the QuerySet contains the given object.

        NOTE: This method only ensures a match by ID, not by full object equality.
        This is intentional, as the object may be outdated or not fully populated.

        Args:
            item: The object or ID to check for

        Returns:
            True if the object is in the QuerySet. False for anything that is neither
            an int nor an object with a non-None ``id``.

        """
        # ID means a match, even if the data is outdated
        pk = item if isinstance(item, int) else getattr(item, "id", None)
        if pk is None:
            # Nothing to compare against: not a member, rather than an AttributeError
            return False

        return any(obj.id == pk for obj in self)

827 return any(obj.id == pk for obj in self)