# Source code for betty.assertion.error

"""
Provide assertion failures.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from contextlib import contextmanager
from textwrap import indent
from typing import Iterator, Self, TYPE_CHECKING, TypeAlias, TypeVar

from typing_extensions import override

from betty.error import UserFacingError
from betty.locale import UNDETERMINED_LOCALE
from betty.locale.localizable import _, Localizable
from betty.locale.localized import LocalizedStr

if TYPE_CHECKING:
    from collections.abc import Sequence, MutableSequence
    from betty.locale.localizer import Localizer


# NOTE(review): this type variable is not referenced anywhere in the visible
# source — confirm whether it is still needed.
_AssertionContextValueT = TypeVar("_AssertionContextValueT")


class AssertionContext(ABC):
    """
    The context in which an assertion is invoked.

    Concrete contexts implement :py:meth:`format` to render themselves as a
    fragment of text.
    """

    @abstractmethod
    def format(self) -> str:
        """
        Format this context to a string.

        :return: The textual representation of this context.
        """
class Attr(AssertionContext):
    """
    An object attribute context.
    """

    def __init__(self, attr: str):
        """
        :param attr: The name of the attribute.
        """
        self._attr = attr

    @override
    def format(self) -> str:
        # The attribute name prefixed with a dot, e.g. ``.name``.
        return f".{self._attr}"
class Index(AssertionContext):
    """
    A sequence index context.
    """

    def __init__(self, index: int):
        """
        :param index: The index into the sequence.
        """
        self._index = index

    @override
    def format(self) -> str:
        # The index wrapped in square brackets, e.g. ``[3]``.
        return f"[{self._index}]"
class Key(AssertionContext):
    """
    A mapping key context.
    """

    def __init__(self, key: str):
        """
        :param key: The key into the mapping.
        """
        self._key = key

    @override
    def format(self) -> str:
        # The key in double quotes and square brackets, e.g. ``["name"]``.
        # The key is inserted verbatim; quotes inside it are not escaped.
        return f'["{self._key}"]'
# Anything accepted as an error context: a structural context or a localizable message.
Contextey: TypeAlias = AssertionContext | Localizable


class _Contexts(Localizable):
    """
    A run of assertion contexts that is localized as one string.
    """

    def __init__(self, context: AssertionContext):
        """
        :param context: The first context of the run.
        """
        # Public and mutable on purpose: ``localizable_contexts()`` appends to it.
        self.contexts: MutableSequence[AssertionContext] = [context]

    @override
    def localize(self, localizer: Localizer) -> LocalizedStr:
        # The given localizer is not used: the result is always the literal
        # ``data`` followed by each formatted context, tagged with
        # UNDETERMINED_LOCALE.
        path = "data" + "".join(context.format() for context in self.contexts)
        return LocalizedStr(path, locale=UNDETERMINED_LOCALE)
def localizable_contexts(*contexts: Contextey) -> Sequence[Localizable]:
    """
    The contexts as :py:class:`betty.locale.localizable.Localizable` instances.

    Contexts that already are localizable are kept as they are. Each run of
    consecutive :py:class:`AssertionContext` instances is merged into a single
    localizable.

    :param contexts: The contexts to convert, in order.
    :return: The localizable contexts, in the same order.
    """
    localizables: MutableSequence[Localizable] = []
    for context in contexts:
        if isinstance(context, Localizable):
            localizables.append(context)
            continue
        # Extend the run in progress if the previous item is one, otherwise
        # start a new run.
        previous = localizables[-1] if localizables else None
        if isinstance(previous, _Contexts):
            previous.contexts.append(context)
        else:
            localizables.append(_Contexts(context))
    return localizables
class AssertionFailed(UserFacingError, ValueError):
    """
    An assertion failure.
    """

    def __init__(
        self, message: Localizable, *, contexts: tuple[Contextey, ...] | None = None
    ):
        """
        :param message: The localizable error message.
        :param contexts: The contexts describing where the error occurred, if any.
        """
        super().__init__(message)
        # ``None`` and an empty tuple both mean "no contexts".
        self._contexts: tuple[Contextey, ...] = contexts or ()

    @override
    def localize(self, localizer: Localizer) -> LocalizedStr:
        message = super().localize(localizer)
        # One line per localizable context, each prefixed with "- ".
        context_lines = indent(
            "\n".join(
                context.localize(localizer)
                for context in localizable_contexts(*self.contexts)
            ),
            "- ",
        )
        # Stripping removes the trailing newline left when there are no contexts.
        return LocalizedStr(
            (message + "\n" + context_lines).strip(),
            locale=localizer.locale,
        )

    def raised(self, error_type: type[AssertionFailed]) -> bool:
        """
        Check if the error matches the given error type.

        :param error_type: The error type to match against.
        :return: Whether this error is an instance of ``error_type``.
        """
        return isinstance(self, error_type)

    @property
    def contexts(self) -> tuple[Contextey, ...]:
        """
        Get the human-readable contexts describing where the error occurred in the source data.
        """
        return self._contexts

    def with_context(self, *contexts: Contextey) -> Self:
        """
        Add a message describing the error's context.

        This error is left untouched; the contexts are set on a copy.

        :param contexts: The contexts to add. They are reversed and placed before
            the contexts this error already has.
        :return: A copy of this error with the contexts added.
        """
        self_copy = self._copy()
        self_copy._contexts = (*reversed(contexts), *self._contexts)
        return self_copy

    def _copy(self) -> Self:
        # Builds a new instance of the same type from the message only.
        # Contexts are not carried over here; ``with_context()`` sets them on the copy.
        return type(self)(self._localizable_message)
class AssertionFailedGroup(AssertionFailed):
    """
    A group of zero or more assertion failures.
    """

    def __init__(
        self,
        errors: Sequence[AssertionFailed] | None = None,
    ):
        """
        :param errors: The errors to add to the group initially, if any.
        """
        super().__init__(_("The following errors occurred"))
        self._errors: MutableSequence[AssertionFailed] = []
        if errors is not None:
            self.append(*errors)

    def __iter__(self) -> Iterator[AssertionFailed]:
        yield from self._errors

    @override
    def localize(self, localizer: Localizer) -> LocalizedStr:
        # Only the contained errors are rendered, separated by blank lines.
        # The group's own message is not part of the output.
        return LocalizedStr(
            "\n\n".join(error.localize(localizer) for error in self._errors),
            locale=localizer.locale,
        )

    def __len__(self) -> int:
        return len(self._errors)

    @override
    def raised(self, error_type: type[AssertionFailed]) -> bool:
        # A group matches if any of the errors it contains matches.
        return any(error.raised(error_type) for error in self._errors)

    @property
    def valid(self) -> bool:
        """
        Check that this collection contains no errors.
        """
        return not self._errors

    @property
    def invalid(self) -> bool:
        """
        Check that this collection contains at least one error.
        """
        return not self.valid

    @contextmanager
    def assert_valid(self, *contexts: Contextey) -> Iterator[Self]:
        """
        Assert that this collection contains no errors.

        :param contexts: The contexts to add to errors raised within the context manager.
        :raises AssertionFailedGroup: This collection itself, if it contains errors
            on entry, or once the managed block has finished.
        """
        if self.invalid:
            raise self
        with self.catch(*contexts):
            yield self
        # Errors caught by ``catch()`` have been added to this collection by now.
        if self.invalid:  # type: ignore[redundant-expr]
            raise self

    def append(self, *errors: AssertionFailed) -> None:
        """
        Append errors to this collection.

        Groups are flattened: their contained errors are appended one by one.

        :param errors: The errors to append.
        """
        for error in errors:
            if isinstance(error, AssertionFailedGroup):
                self.append(*error)
            else:
                # NOTE(review): ``with_context()`` reverses the contexts it is
                # given, so the group's contexts end up on the error in the
                # reverse of the order in which the group stores them. Confirm
                # this is intended for groups with more than one context.
                self._errors.append(error.with_context(*self._contexts))

    @override
    def with_context(self, *contexts: Contextey) -> Self:
        self_copy = super().with_context(*contexts)
        # The copy starts out empty, so give it context-carrying copies of the errors.
        self_copy._errors = [error.with_context(*contexts) for error in self._errors]
        return self_copy

    @override
    def _copy(self) -> Self:
        # A new, empty group. Neither errors nor contexts are carried over here.
        return type(self)()

    @contextmanager
    def catch(self, *contexts: Contextey) -> Iterator[AssertionFailedGroup]:
        """
        Catch any errors raised within this context manager and add them to the collection.

        Only :py:class:`AssertionFailed` errors are caught. Any other exception
        propagates, and in that case nothing is added to this collection.

        :param contexts: The contexts to add to the caught errors.
        :return: A new collection that will only contain any newly raised errors.
        """
        context_errors: AssertionFailedGroup = AssertionFailedGroup()
        if contexts:
            context_errors = context_errors.with_context(*contexts)
        try:
            yield context_errors
        except AssertionFailed as e:
            context_errors.append(e)
        # Also covers errors that were appended to the yielded collection
        # directly rather than raised.
        self.append(*context_errors)