"""
Entity associations.
"""
from __future__ import annotations
import weakref
from abc import abstractmethod, ABC
from typing import (
Generic,
cast,
Any,
Iterable,
TypeVar,
final,
Self,
overload,
TYPE_CHECKING,
)
from urllib.parse import quote
from typing_extensions import override
from betty.importlib import import_any
from betty.json.linked_data import LinkedDataDumpableProvider
from betty.json.schema import Schema, Array, OneOf, Null
from betty.model import (
Entity,
EntityReferenceSchema,
EntityReferenceCollectionSchema,
UserFacingEntity,
persistent_id,
)
from betty.model.collections import EntityCollection, SingleTypeEntityCollection
from betty.typing import internal
if TYPE_CHECKING:
from betty.project import Project
from betty.serde.dump import Dump
_T = TypeVar("_T")
_EntityT = TypeVar("_EntityT", bound=Entity)
_OwnerT = TypeVar("_OwnerT", bound=Entity)
_AssociateT = TypeVar("_AssociateT", bound=Entity)
async def _generate_associate_url(project: Project, associate: Entity) -> str | None:
    """
    Generate the URL to the JSON resource of an associate.

    :param project: The project whose URL generator is used.
    :param associate: The entity to generate the URL for.
    :return: The generated URL, or ``None`` if ``persistent_id(associate)`` is falsy or the associate is not a
        :py:class:`betty.model.UserFacingEntity`.
    """
    linkable = persistent_id(associate) and isinstance(associate, UserFacingEntity)
    if not linkable:
        return None
    url_generator = await project.url_generator
    # The entity ID is percent-encoded so it is safe to use as a path segment.
    resource = (
        f"betty-static:///{associate.type.plugin_id()}/{quote(associate.id)}/index.json"
    )
    return url_generator.generate(resource)
[docs]
class AssociationRequired(RuntimeError):
"""
Raised when an operation cannot be performed because the association in question is required.
"""
[docs]
@classmethod
def new(cls, association: _Association[_OwnerT, Any], owner: _OwnerT) -> Self:
"""
Create a new instance.
"""
return cls(
f"Association {association._owner_type_name}.{association.owner_attr_name} is required, but missing for {owner}."
)
class _Resolver(Generic[_T], ABC):
    """
    The base for objects that can be set on an association in place of the actual associate(s).

    Associations replace a resolver by the value it resolves to when they are resolved.
    """

    @abstractmethod
    def resolve(self) -> _T:
        """
        Resolve to the entity or entities this resolver stands in for.

        :raises ResolutionError: Raised if resolution failed.
        """
[docs]
class ToZeroOrOneResolver(Generic[_EntityT], _Resolver[_EntityT | None]):
    """
    A resolver whose resolution yields either a single entity or ``None``.
    """
[docs]
class ToOneResolver(Generic[_EntityT], _Resolver[_EntityT]):
    """
    A resolver whose resolution yields exactly one entity.
    """
[docs]
class ToManyResolver(Generic[_EntityT], _Resolver[Iterable[_EntityT]]):
    """
    A resolver whose resolution yields an iterable of entities.
    """
class _TemporaryResolver(Generic[_T], _Resolver[_T]):
    """
    The base for placeholder resolvers that must be replaced before they are resolved.
    """

    @override
    def resolve(self) -> _T:
        """
        Always fail, because a temporary resolver has nothing to resolve to.

        :raises RuntimeError: Always raised.
        """
        message = "This temporary resolver was supposed to be replaced. It intentionally cannot resolve itself."
        raise RuntimeError(message)
[docs]
class TemporaryToZeroOrOneResolver(
    Generic[_EntityT], _TemporaryResolver[_EntityT], ToZeroOrOneResolver[_EntityT]
):
    """
    A placeholder to-zero-or-one resolver.

    Use this to satisfy association requirements in multiple steps. This resolver can never resolve itself, so
    users **MUST** replace it by a real value before associations are resolved.
    """
[docs]
class TemporaryToOneResolver(
    Generic[_EntityT], _TemporaryResolver[_EntityT], ToOneResolver[_EntityT]
):
    """
    A placeholder to-one resolver.

    Use this to satisfy association requirements in multiple steps. This resolver can never resolve itself, so
    users **MUST** replace it by a real value before associations are resolved.
    """
[docs]
class TemporaryToManyResolver(
    Generic[_EntityT], _TemporaryResolver[_EntityT], ToManyResolver[_EntityT]
):
    """
    A placeholder to-many resolver.

    Use this to satisfy association requirements in multiple steps. This resolver can never resolve itself, so
    users **MUST** replace it by a real value before associations are resolved.
    """
class _Association(LinkedDataDumpableProvider[_OwnerT], Generic[_OwnerT, _AssociateT]):
    """
    The base for all entity type associations.

    An association links an attribute on an owner type to an associate type. Each association registers itself
    with :py:class:`AssociationRegistry` when it is created.
    """

    def __init__(
        self,
        owner_type_name: str,
        owner_attr_name: str,
        associate_type_name: str,
        *,
        title: str | None = None,
        description: str | None = None,
        linked_data_embedded: bool = False,
    ):
        """
        :param owner_type_name: The name of the owner type, importable through ``import_any()``.
        :param owner_attr_name: The name of the attribute on the owner that holds this association.
        :param associate_type_name: The name of the associate type, importable through ``import_any()``.
        :param title: An optional title, stored for use in linked data schemas.
        :param description: An optional description, stored for use in linked data schemas.
        :param linked_data_embedded: A flag stored for the concrete associations' linked data schemas and dumps.
        """
        # Type names are kept as strings and only imported on demand by the ``*_type`` properties.
        self._owner_type_name = owner_type_name
        self._associate_type_name = associate_type_name
        self._owner_attr_name = owner_attr_name
        # Values are stored on owners under the public attribute name prefixed with an underscore.
        self._internal_owner_attr_name = f"_{owner_attr_name}"
        self._title = title
        self._description = description
        self._linked_data_embedded = linked_data_embedded
        # Register last: the registry keeps associations in a set, and hashing reads the attributes set above.
        AssociationRegistry._register(self)

    def __hash__(self) -> int:
        identity = (
            type(self),
            self._owner_type_name,
            self._owner_attr_name,
            self._associate_type_name,
            self._linked_data_embedded,
            self._title,
            self._description,
        )
        return hash(identity)

    @property
    def owner_type(self) -> type[_OwnerT]:
        """
        The type of the owning entity that contains this association.

        This may be an abstract class.
        """
        owner_type = import_any(self._owner_type_name)
        return cast(type[_OwnerT], owner_type)

    @property
    def owner_attr_name(self) -> str:
        """
        The name of the attribute on the owning entity that contains this association.
        """
        return self._owner_attr_name

    @property
    def associate_type(self) -> type[_AssociateT]:
        """
        The type of any associate entities.

        This may be an abstract class.
        """
        associate_type = import_any(self._associate_type_name)
        return cast(type[_AssociateT], associate_type)

    @abstractmethod
    def resolve(self, owner: _OwnerT) -> None:
        """
        Resolve any associates the owner may have for this association.
        """

    @abstractmethod
    def associate(self, owner: _OwnerT, associate: _AssociateT) -> None:
        """
        Associate two entities.
        """

    @abstractmethod
    def disassociate(self, owner: _OwnerT, associate: _AssociateT) -> None:
        """
        Disassociate two entities.

        :raises AssociationRequired: Raised if the association is required and the disassociation would leave it
            without any associates.
        """

    @abstractmethod
    def get_associates(self, owner: _OwnerT) -> Iterable[_AssociateT]:
        """
        Get the associates for the given owner.
        """
class _ToOneAssociation(
    Generic[_OwnerT, _AssociateT], _Association[_OwnerT, _AssociateT]
):
    """
    The base for associations in which an owner is required to have exactly one associate.

    This is a descriptor: the associate (or a resolver for it) is stored on the owner under the association's
    internal attribute name.
    """

    @override
    def associate(self, owner: _OwnerT, associate: _AssociateT) -> None:
        self.__set__(owner, associate)

    @override
    def disassociate(self, owner: _OwnerT, associate: _AssociateT) -> None:
        # Only clear the stored value if it is the given associate, like the to-zero-or-one association does.
        # Clearing unconditionally would let a stale disassociation (e.g. coming from an inverse collection the
        # owner no longer belongs to) wipe an unrelated, current associate.
        current = getattr(owner, self._internal_owner_attr_name, None)
        if associate == current:
            setattr(owner, self._internal_owner_attr_name, None)

    @overload
    def __get__(self, instance: None, owner: type[_OwnerT]) -> Self:
        pass

    @overload
    def __get__(self, instance: _OwnerT, owner: type[_OwnerT]) -> _AssociateT:
        pass

    def __get__(self, instance: _OwnerT | None, owner: type[_OwnerT]):
        """
        Get the associate for an owner, or the association itself when accessed on the owner type.

        :raises AssociationRequired: Raised if the owner has no associate.
        """
        if instance is None:
            # Accessed on the class rather than an instance.
            return self  # type: ignore[return-value]
        try:
            value = getattr(instance, self._internal_owner_attr_name)
        except AttributeError:
            raise AssociationRequired.new(self, instance) from None
        else:
            if value is None:
                raise AssociationRequired.new(self, instance)
            # Resolvers must have been resolved or replaced before the associate is read.
            assert not isinstance(value, _Resolver)
            return cast(_AssociateT, value)

    def __set__(
        self, instance: _OwnerT, value: _AssociateT | ToOneResolver[_AssociateT]
    ) -> None:
        """
        Store an associate, or a resolver for it, on the owner.
        """
        setattr(instance, self._internal_owner_attr_name, value)

    @override
    def get_associates(self, owner: _OwnerT) -> Iterable[_AssociateT]:
        yield self.__get__(owner, type(owner))

    @override
    async def linked_data_schema_for(self, project: Project) -> Schema:
        # Embedded associations use the associate type's own schema, others a reference to the entity.
        schema = (
            await self.associate_type.linked_data_schema(project)
            if self._linked_data_embedded
            else EntityReferenceSchema()
        )
        return OneOf(
            schema,
            Null(description="In case the entity is not publishable."),
            title=self._title or schema.title,
            description=self._description or schema.description,
        )

    @override
    async def dump_linked_data_for(
        self, project: Project, target: _OwnerT & Entity
    ) -> Dump:
        associate = self.__get__(target, type(target))
        if self._linked_data_embedded:
            return await associate.dump_linked_data(project)
        else:
            return await _generate_associate_url(project, associate)
class _ToZeroOrOneAssociation(
    Generic[_OwnerT, _AssociateT], _Association[_OwnerT, _AssociateT]
):
    """
    The base for associations in which an owner has at most one associate.

    This is a descriptor: the associate, a resolver for it, or ``None`` is stored on the owner under the
    association's internal attribute name.
    """

    @override
    def associate(self, owner: _OwnerT, associate: _AssociateT) -> None:
        self.__set__(owner, associate)

    @override
    def disassociate(self, owner: _OwnerT, associate: _AssociateT) -> None:
        # Disassociating an entity that is not the current associate is a no-op.
        current = self.__get__(owner, type(owner))
        if associate == current:
            self.__delete__(owner)

    @overload
    def __get__(self, instance: None, owner: type[_OwnerT]) -> Self: ...

    @overload
    def __get__(
        self, instance: _OwnerT, owner: type[_OwnerT]
    ) -> _AssociateT | None: ...

    def __get__(self, instance: _OwnerT | None, owner: type[_OwnerT]):
        """
        Get the associate for an owner, or the association itself when accessed on the owner type.
        """
        if instance is None:
            # Accessed on the class rather than an instance.
            return self  # type: ignore[return-value]
        attr_name = self._internal_owner_attr_name
        try:
            stored = getattr(instance, attr_name)
        except AttributeError:
            # Nothing was ever set: initialize the owner's attribute to "no associate".
            setattr(instance, attr_name, None)
            return None
        # Resolvers must have been resolved or replaced before the associate is read.
        assert not isinstance(stored, _Resolver)
        return cast(_AssociateT | None, stored)

    def __set__(
        self,
        instance: _OwnerT,
        value: _AssociateT
        | ToZeroOrOneResolver[_AssociateT]
        | ToOneResolver[_AssociateT]
        | None,
    ) -> None:
        """
        Store an associate, a resolver for it, or ``None`` on the owner.
        """
        setattr(instance, self._internal_owner_attr_name, value)

    def __delete__(self, instance: _OwnerT) -> None:
        """
        Remove the owner's associate by setting the association to ``None``.
        """
        self.__set__(instance, None)

    @override
    def get_associates(self, owner: _OwnerT) -> Iterable[_AssociateT]:
        associate = self.__get__(owner, type(owner))
        if associate is None:
            return
        yield associate

    @override
    async def linked_data_schema_for(self, project: Project) -> Schema:
        # Embedded associations use the associate type's own schema, others a reference to the entity.
        if self._linked_data_embedded:
            schema = await self.associate_type.linked_data_schema(project)
        else:
            schema = EntityReferenceSchema()
        return OneOf(
            schema,
            Null(description="In case the entity is not publishable."),
            title=self._title or schema.title,
            description=self._description or schema.description,
        )

    @override
    async def dump_linked_data_for(
        self, project: Project, target: _OwnerT & Entity
    ) -> Dump:
        associate = self.__get__(target, type(target))
        if associate is None:
            return None
        if self._linked_data_embedded:
            return await associate.dump_linked_data(project)
        return await _generate_associate_url(project, associate)
@internal
class _ToManyAssociation(
    Generic[_OwnerT, _AssociateT], _Association[_OwnerT, _AssociateT]
):
    """
    The base for associations in which an owner has any number of associates.

    This is a descriptor: a collection of associates, or a resolver for them, is stored on the owner under the
    association's internal attribute name.
    """

    def _new_collection(self, instance: _OwnerT) -> EntityCollection[_AssociateT]:
        """
        Create a new, empty collection to hold an owner's associates.
        """
        return SingleTypeEntityCollection[_AssociateT](self.associate_type)

    @overload
    def __get__(self, instance: None, owner: type[_OwnerT]) -> Self:
        pass

    @overload
    def __get__(
        self, instance: _OwnerT, owner: type[_OwnerT]
    ) -> EntityCollection[_AssociateT]:
        pass

    def __get__(self, instance: _OwnerT | None, owner: type[_OwnerT]):
        """
        Get the collection of associates for an owner, or the association itself when accessed on the owner type.

        The collection is created and stored on the owner upon first access.
        """
        if instance is None:
            # Accessed on the class rather than an instance.
            return self  # type: ignore[return-value]
        try:
            value = getattr(instance, self._internal_owner_attr_name)
        except AttributeError:
            value = self._new_collection(instance)
            setattr(instance, self._internal_owner_attr_name, value)
            return value
        else:
            # Resolvers must have been resolved or replaced before the associates are read.
            assert not isinstance(value, _Resolver)
            return cast(EntityCollection[_AssociateT], value)

    def __set__(
        self,
        instance: _OwnerT,
        value: Iterable[_AssociateT] | ToManyResolver[_AssociateT],
    ) -> None:
        """
        Replace the owner's associates by the given entities, or store a resolver for them.
        """
        attr_name = self._internal_owner_attr_name
        if isinstance(value, _Resolver):
            setattr(instance, attr_name, value)
            return
        if isinstance(getattr(instance, attr_name, None), _Resolver):
            # A pending (e.g. temporary) resolver is being replaced by real associates. __get__() refuses to
            # return a resolver, so swap the resolver for a fresh collection before filling it.
            setattr(instance, attr_name, self._new_collection(instance))
        self.__get__(instance, type(instance)).replace(*value)

    def __delete__(self, instance: _OwnerT) -> None:
        """
        Remove all of the owner's associates.
        """
        self.__get__(instance, type(instance)).clear()

    @override
    def associate(self, owner: _OwnerT, associate: _AssociateT) -> None:
        self.__get__(owner, type(owner)).add(associate)

    @override
    def disassociate(self, owner: _OwnerT, associate: _AssociateT) -> None:
        self.__get__(owner, type(owner)).remove(associate)

    @override
    def get_associates(self, owner: _OwnerT) -> Iterable[_AssociateT]:
        yield from self.__get__(owner, type(owner))

    @override
    def resolve(self, owner: _OwnerT) -> None:
        value = getattr(owner, self._internal_owner_attr_name, None)
        if isinstance(value, _Resolver):
            # Store the collection before adding to it, so it is the owner's collection by the time entities
            # are added.
            collection = self._new_collection(owner)
            setattr(owner, self._internal_owner_attr_name, collection)
            collection.add(*value.resolve())

    @override
    async def linked_data_schema_for(self, project: Project) -> Schema:
        if self._linked_data_embedded:
            return Array(
                await self.associate_type.linked_data_schema(project),
                title=self._title,
                description=self._description,
            )
        return EntityReferenceCollectionSchema(
            title=self._title, description=self._description
        )

    @override
    async def dump_linked_data_for(
        self, project: Project, target: _OwnerT & Entity
    ) -> Dump:
        associates = self.__get__(target, type(target))
        if self._linked_data_embedded:
            return [
                await associate.dump_linked_data(project) for associate in associates
            ]
        # Associates for which no URL can be generated are left out of the dump.
        return list(
            filter(
                None,
                [
                    await _generate_associate_url(project, associate)
                    for associate in associates
                ],
            )
        )
class _BidirectionalAssociation(
    Generic[_OwnerT, _AssociateT], _Association[_OwnerT, _AssociateT]
):
    """
    The base for associations that have an inverse association on the associate type.
    """

    def __init__(
        self,
        owner_type_name: str,
        owner_attr_name: str,
        associate_type_name: str,
        associate_attr_name: str,
        *,
        linked_data_embedded: bool = False,
        title: str | None = None,
        description: str | None = None,
    ):
        """
        :param owner_type_name: The name of the owner type, importable through ``import_any()``.
        :param owner_attr_name: The name of the attribute on the owner that holds this association.
        :param associate_type_name: The name of the associate type, importable through ``import_any()``.
        :param associate_attr_name: The name of the attribute on the associate that holds the inverse association.
        :param linked_data_embedded: A flag stored for the concrete associations' linked data schemas and dumps.
        :param title: An optional title, stored for use in linked data schemas.
        :param description: An optional description, stored for use in linked data schemas.
        """
        # This must be set before calling the parent initializer: the parent registers the association in a
        # set, which hashes it, and __hash__() below reads this attribute.
        self._associate_attr_name = associate_attr_name
        super().__init__(
            owner_type_name,
            owner_attr_name,
            associate_type_name,
            title=title,
            description=description,
            linked_data_embedded=linked_data_embedded,
        )

    @override
    def __hash__(self) -> int:
        base_hash = super().__hash__()
        return hash((base_hash, self._associate_attr_name))

    @property
    def associate_attr_name(self) -> str:
        """
        The association's attribute name on the associate type.
        """
        return self._associate_attr_name

    def inverse(self) -> _BidirectionalAssociation[_AssociateT, _OwnerT]:
        """
        Get the inverse association.

        The inverse is looked up in the association registry by associate type and associate attribute name.
        """
        inverse_association = AssociationRegistry.get_association(
            self.associate_type, self.associate_attr_name
        )
        assert isinstance(inverse_association, _BidirectionalAssociation)
        return inverse_association
[docs]
class BidirectionalToZeroOrOne(
    Generic[_OwnerT, _AssociateT],
    _ToZeroOrOneAssociation[_OwnerT, _AssociateT],
    _BidirectionalAssociation[_OwnerT, _AssociateT],
):
    """
    A bidirectional *-to-zero-or-one entity type association.
    """

    def __set__(
        self,
        instance: _OwnerT,
        value: _AssociateT
        | ToZeroOrOneResolver[_AssociateT]
        | ToOneResolver[_AssociateT]
        | None,
    ) -> None:
        """
        Store an associate, a resolver for it, or ``None`` on the owner, and keep the inverse association in sync.
        """
        # Read the stored value directly: __get__() asserts that no resolver is pending, which would make it
        # impossible to replace a (temporary) resolver by a real value.
        stored = getattr(instance, self._internal_owner_attr_name, None)
        if stored == value:
            return
        super().__set__(instance, value)
        # A resolver was never associated on the inverse side, so only real entities are disassociated.
        if stored is not None and not isinstance(stored, _Resolver):
            self.inverse().disassociate(stored, instance)
        # Resolvers are associated on the inverse side once they are resolved.
        if value is not None and not isinstance(value, _Resolver):
            self.inverse().associate(value, instance)

    @override
    def resolve(self, owner: _OwnerT) -> None:
        value = getattr(owner, self._internal_owner_attr_name, None)
        if not isinstance(value, _Resolver):
            return
        # Resolve exactly once, so the stored associate and the one associated on the inverse side are the
        # same object.
        associate = value.resolve()
        setattr(owner, self._internal_owner_attr_name, associate)
        if associate is not None:
            self.inverse().associate(associate, owner)
[docs]
class BidirectionalToOne(
    Generic[_OwnerT, _AssociateT],
    _ToOneAssociation[_OwnerT, _AssociateT],
    _BidirectionalAssociation[_OwnerT, _AssociateT],
):
    """
    A bidirectional *-to-one entity type association.
    """

    @override
    def resolve(self, owner: _OwnerT) -> None:
        """
        Resolve the owner's associate and associate the owner on the inverse side.

        :raises AssociationRequired: Raised if the owner has neither an associate nor a resolver.
        """
        value = getattr(owner, self._internal_owner_attr_name, None)
        if value is None:
            raise AssociationRequired.new(self, owner)
        if isinstance(value, _Resolver):
            associate = value.resolve()
            setattr(owner, self._internal_owner_attr_name, associate)
            self.inverse().associate(associate, owner)

    def __set__(
        self, instance: _OwnerT, value: _AssociateT | ToOneResolver[_AssociateT]
    ) -> None:
        """
        Store an associate, or a resolver for it, on the owner, and keep the inverse association in sync.
        """
        # The previous value lives on the owner instance, not on this descriptor.
        stored = getattr(instance, self._internal_owner_attr_name, None)
        if stored == value:
            return
        # A resolver was never associated on the inverse side, so only real entities are disassociated.
        if stored is not None and not isinstance(stored, _Resolver):
            # Disassociate before storing the new value: the inverse may call back into this association's
            # disassociate(), which clears the owner's stored value.
            self.inverse().disassociate(stored, instance)
        super().__set__(instance, value)
        # Resolvers are associated on the inverse side once they are resolved.
        if value is not None and not isinstance(value, _Resolver):
            self.inverse().associate(value, instance)
[docs]
class BidirectionalToMany(
    Generic[_OwnerT, _AssociateT],
    _ToManyAssociation[_OwnerT, _AssociateT],
    _BidirectionalAssociation[_OwnerT, _AssociateT],
):
    """
    A bidirectional *-to-many entity type association.
    """

    @override
    def _new_collection(self, instance: _OwnerT) -> EntityCollection[_AssociateT]:
        """
        Create a collection that updates the inverse association whenever associates are added or removed.
        """
        collection = _BidirectionalAssociateCollection(instance, self)
        return collection
[docs]
@final
class UnidirectionalToZeroOrOne(
    Generic[_OwnerT, _AssociateT], _ToZeroOrOneAssociation[_OwnerT, _AssociateT]
):
    """
    A unidirectional to-zero-or-one entity type association.
    """

    @override
    def resolve(self, owner: _OwnerT) -> None:
        """
        Replace a resolver stored on the owner by the value it resolves to.
        """
        attr_name = self._internal_owner_attr_name
        stored = getattr(owner, attr_name, None)
        if not isinstance(stored, _Resolver):
            # There is nothing to resolve.
            return
        setattr(owner, attr_name, stored.resolve())
[docs]
@final
class UnidirectionalToOne(
    Generic[_OwnerT, _AssociateT], _ToOneAssociation[_OwnerT, _AssociateT]
):
    """
    A unidirectional to-one entity type association.
    """

    @override
    def resolve(self, owner: _OwnerT) -> None:
        """
        Replace a resolver stored on the owner by the value it resolves to.

        :raises AssociationRequired: Raised if the owner has neither an associate nor a resolver.
        """
        attr_name = self._internal_owner_attr_name
        stored = getattr(owner, attr_name, None)
        if stored is None:
            raise AssociationRequired.new(self, owner)
        if not isinstance(stored, _Resolver):
            # The associate is already a real value.
            return
        setattr(owner, attr_name, stored.resolve())
[docs]
@final
class UnidirectionalToMany(
    Generic[_OwnerT, _AssociateT], _ToManyAssociation[_OwnerT, _AssociateT]
):
    """
    A unidirectional to-many entity type association.

    All behavior is inherited from the to-many association base.
    """
[docs]
@final
class AssociationRegistry:
    """
    Inspect any known entity type associations.
    """

    # Every association that has been created, shared by the whole process.
    _associations = set[_Association[Any, Any]]()

    @classmethod
    def get_all_associations(cls, owner: type | object) -> set[_Association[Any, Any]]:
        """
        Get all associations for an owner.

        :param owner: The owner type, or an instance of it.
        :return: The associations whose owner type is the given type or any of its parent classes.
        """
        if isinstance(owner, type):
            owner_type = owner
        else:
            owner_type = type(owner)
        return {
            candidate
            for candidate in cls._associations
            if candidate.owner_type in owner_type.__mro__
        }

    @classmethod
    def get_association(
        cls, owner: type[_OwnerT] | _OwnerT, owner_attr_name: str
    ) -> _Association[_OwnerT, Any]:
        """
        Get the association for a given owner and attribute name.

        :raises ValueError: Raised if the owner has no association for the given attribute name.
        """
        matches = (
            candidate
            for candidate in cls.get_all_associations(owner)
            if candidate.owner_attr_name == owner_attr_name
        )
        match = next(matches, None)
        if match is None:
            raise ValueError(
                f"No association exists for {owner if isinstance(owner, type) else owner.__class__}.{owner_attr_name}."
            )
        return match

    @classmethod
    def _register(cls, association: _Association[Any, Any]) -> None:
        """
        Add an association to the registry.
        """
        cls._associations.add(association)
class _BidirectionalAssociateCollection(
    Generic[_AssociateT, _OwnerT], SingleTypeEntityCollection[_AssociateT]
):
    """
    A collection of associates that updates the inverse association when associates are added or removed.
    """

    __slots__ = "__owner", "_association"

    def __init__(
        self,
        owner: _OwnerT,
        association: _BidirectionalAssociation[_OwnerT, _AssociateT],
    ):
        """
        :param owner: The entity whose associates this collection holds.
        :param association: The association this collection holds the associates for.
        """
        super().__init__(association.associate_type)
        # The owner is held through a weak reference, so the collection does not keep it alive.
        self.__owner = weakref.ref(owner)
        self._association = association

    @property
    def _owner(self) -> _OwnerT:
        """
        The entity whose associates this collection holds.
        """
        referent = self.__owner()
        assert referent is not None, (
            "This associate collection's owner no longer exists in memory."
        )
        return referent

    @override
    def _on_add(self, *entities: _AssociateT) -> None:
        super()._on_add(*entities)
        # Associate the owner with each added entity on the inverse side.
        for added in entities:
            self._association.inverse().associate(added, self._owner)

    @override
    def _on_remove(self, *entities: _AssociateT) -> None:
        super()._on_remove(*entities)
        # Disassociate the owner from each removed entity on the inverse side.
        for removed in entities:
            self._association.inverse().disassociate(removed, self._owner)
[docs]
def resolve(*entities: Entity) -> None:
    """
    Resolve the associates of all given entities.

    Every association registered for an entity's type is asked to resolve that entity.

    You **MUST** call this on all entities once the resolvers you have set on them can indeed be resolved.
    """
    for owner in entities:
        associations = AssociationRegistry.get_all_associations(owner)
        for association in associations:
            association.resolve(owner)