"""
Entity collections.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from reprlib import recursive_repr
from typing import (
Generic,
Iterator,
overload,
Any,
Iterable,
cast,
TypeVar,
TYPE_CHECKING,
Self,
)
from typing_extensions import override
from betty.functools import unique
from betty.model import Entity
from betty.repr import repr_instance
if TYPE_CHECKING:
from betty.plugin import PluginIdToTypeMapping
from betty.machine_name import MachineName
from collections.abc import Sequence, MutableSequence, MutableMapping, AsyncIterator
_EntityT = TypeVar("_EntityT", bound=Entity)
_TargetT = TypeVar("_TargetT")
[docs]
class EntityCollection(Generic[_TargetT], ABC):
"""
Provide a collection of entities.
"""
__slots__ = ()
[docs]
def __init__(self):
super().__init__()
def _on_add(self, *entities: _TargetT & Entity) -> None:
pass
def _on_remove(self, *entities: _TargetT & Entity) -> None:
pass
@property
def view(self) -> Sequence[_TargetT & Entity]:
"""
A view of the entities at the time of calling.
"""
return [*self]
[docs]
@abstractmethod
def add(self, *entities: _TargetT & Entity) -> None:
"""
Add the given entities.
"""
pass
[docs]
@abstractmethod
def remove(self, *entities: _TargetT & Entity) -> None:
"""
Remove the given entities.
"""
pass
[docs]
def replace(self, *entities: _TargetT & Entity) -> None:
"""
Replace all entities with the given ones.
"""
self.remove(*(entity for entity in self if entity not in entities))
self.add(*entities)
[docs]
@abstractmethod
def clear(self) -> None:
"""
Clear all entities from the collection.
"""
pass
@abstractmethod
def __iter__(self) -> Iterator[_TargetT & Entity]:
pass
@abstractmethod
def __len__(self) -> int:
pass
@overload
def __getitem__(self, index: int) -> _TargetT & Entity:
pass
@overload
def __getitem__(self, indices: slice) -> Sequence[_TargetT & Entity]:
pass
@abstractmethod
def __getitem__(
self, key: int | slice
) -> _TargetT & Entity | Sequence[_TargetT & Entity]:
pass
@abstractmethod
def __delitem__(self, key: _TargetT & Entity) -> None:
pass
@abstractmethod
def __contains__(self, value: Any) -> bool:
pass
def _known(self, *entities: _TargetT & Entity) -> Iterable[_TargetT & Entity]:
for entity in unique(entities):
if entity in self:
yield entity
def _unknown(self, *entities: _TargetT & Entity) -> Iterable[_TargetT & Entity]:
for entity in unique(entities):
if entity not in self:
yield entity
_EntityCollectionT = TypeVar("_EntityCollectionT", bound=EntityCollection[_EntityT])
[docs]
class SingleTypeEntityCollection(Generic[_TargetT], EntityCollection[_TargetT]):
"""
Collect entities of a single type.
"""
__slots__ = "_entities", "_target_type"
[docs]
def __init__(self, target_type: type[_TargetT], *entities: _TargetT & Entity):
super().__init__()
self._entities: MutableSequence[_TargetT & Entity] = [*entities]
self._target_type = target_type
@override # type: ignore[callable-functiontype]
@recursive_repr()
def __repr__(self) -> str:
return repr_instance(self, target_type=self._target_type, length=len(self))
[docs]
@override
def add(self, *entities: _TargetT & Entity) -> None:
added_entities = [*self._unknown(*entities)]
for entity in added_entities:
self._entities.append(entity)
if added_entities:
self._on_add(*added_entities)
[docs]
@override
def remove(self, *entities: _TargetT & Entity) -> None:
removed_entities = [*self._known(*entities)]
for entity in removed_entities:
self._entities.remove(entity)
if removed_entities:
self._on_remove(*removed_entities)
[docs]
@override
def clear(self) -> None:
self.remove(*self)
@override
def __iter__(self) -> Iterator[_TargetT & Entity]:
return self._entities.__iter__()
@override
def __len__(self) -> int:
return len(self._entities)
@overload
def __getitem__(self, index: int) -> _TargetT & Entity:
pass
@overload
def __getitem__(self, indices: slice) -> Sequence[_TargetT & Entity]:
pass
@overload
def __getitem__(self, entity_id: str) -> _TargetT & Entity:
pass
@override
def __getitem__(
self, key: int | slice | str
) -> _TargetT & Entity | Sequence[_TargetT & Entity]:
if isinstance(key, int):
return self._getitem_by_index(key)
if isinstance(key, slice):
return self._getitem_by_indices(key)
return self._getitem_by_entity_id(key)
def _getitem_by_index(self, index: int) -> _TargetT & Entity:
return self._entities[index]
def _getitem_by_indices(self, indices: slice) -> Sequence[_TargetT & Entity]:
return self.view[indices]
def _getitem_by_entity_id(self, entity_id: str) -> _TargetT & Entity:
for entity in self._entities:
if entity_id == entity.id:
return entity
raise KeyError(
f'Cannot find a {self._target_type} entity with ID "{entity_id}".'
)
@override
def __delitem__(self, key: str | _TargetT & Entity) -> None:
if isinstance(key, self._target_type):
return self._delitem_by_entity(cast("_TargetT & Entity", key))
if isinstance(key, str):
return self._delitem_by_entity_id(key)
raise TypeError(f"Cannot find entities by {repr(key)}.")
def _delitem_by_entity(self, entity: _TargetT & Entity) -> None:
self.remove(entity)
def _delitem_by_entity_id(self, entity_id: str) -> None:
for entity in self._entities:
if entity_id == entity.id:
self.remove(entity)
return
@override
def __contains__(self, value: Any) -> bool:
if isinstance(value, self._target_type):
return self._contains_by_entity(cast("_TargetT & Entity", value))
if isinstance(value, str):
return self._contains_by_entity_id(value)
return False
def _contains_by_entity(self, other_entity: _TargetT & Entity) -> bool:
return any(other_entity is entity for entity in self._entities)
def _contains_by_entity_id(self, entity_id: str) -> bool:
return any(entity.id == entity_id for entity in self._entities)
[docs]
class MultipleTypesEntityCollection(Generic[_TargetT], EntityCollection[_TargetT]):
"""
Collect entities of multiple types.
"""
__slots__ = ("_collections", "_entity_type_id_to_type_mapping")
[docs]
def __init__(
self,
*entities: _TargetT & Entity,
entity_type_id_to_type_mapping: PluginIdToTypeMapping[Entity],
):
super().__init__()
self._entity_type_id_to_type_mapping = entity_type_id_to_type_mapping
self._collections: MutableMapping[
type[Entity], SingleTypeEntityCollection[Entity]
] = {}
self.add(*entities)
[docs]
@classmethod
async def new(cls, *entities: _TargetT & Entity) -> Self:
"""
Create a new instance.
"""
from betty.model import ENTITY_TYPE_REPOSITORY
return cls(
*entities,
entity_type_id_to_type_mapping=await ENTITY_TYPE_REPOSITORY.mapping(),
)
@override # type: ignore[callable-functiontype]
@recursive_repr()
def __repr__(self) -> str:
return repr_instance(
self,
entity_types=", ".join(
entity_type.plugin_id() for entity_type in self._collections
),
length=len(self),
)
def _get_collection(
self, entity_type: type[_EntityT]
) -> SingleTypeEntityCollection[_EntityT]:
assert issubclass(entity_type, Entity), f"{entity_type} is not an entity type."
try:
return cast(
SingleTypeEntityCollection[_EntityT], self._collections[entity_type]
)
except KeyError:
self._collections[entity_type] = SingleTypeEntityCollection(entity_type)
return cast(
SingleTypeEntityCollection[_EntityT], self._collections[entity_type]
)
@overload
def __getitem__(self, index: int) -> _TargetT & Entity:
pass
@overload
def __getitem__(self, indices: slice) -> Sequence[_TargetT & Entity]:
pass
@overload
def __getitem__(
self, entity_type_id: MachineName
) -> SingleTypeEntityCollection[Entity]:
pass
@overload
def __getitem__(
self, entity_type: type[_EntityT]
) -> SingleTypeEntityCollection[_EntityT]:
pass
@override
def __getitem__(
self,
key: int | slice | str | type[_EntityT],
) -> (
_TargetT & Entity
| SingleTypeEntityCollection[Entity]
| SingleTypeEntityCollection[_EntityT]
| Sequence[_TargetT & Entity]
):
if isinstance(key, int):
return self._getitem_by_index(key)
if isinstance(key, slice):
return self._getitem_by_indices(key)
if isinstance(key, str):
return self._getitem_by_entity_type_id(key)
return self._getitem_by_entity_type(key)
def _getitem_by_entity_type(
self, entity_type: type[_EntityT]
) -> SingleTypeEntityCollection[_EntityT]:
return self._get_collection(entity_type)
def _getitem_by_entity_type_id(
self, entity_type_id: MachineName
) -> SingleTypeEntityCollection[Entity]:
return self._get_collection(
self._entity_type_id_to_type_mapping[entity_type_id]
)
def _getitem_by_index(self, index: int) -> _TargetT & Entity:
return self.view[index]
def _getitem_by_indices(self, indices: slice) -> Sequence[_TargetT & Entity]:
return self.view[indices]
@override
def __delitem__(
self, key: str | type[_TargetT & Entity] | _TargetT & Entity
) -> None:
if isinstance(key, type):
return self._delitem_by_entity_type(
key,
)
if isinstance(key, Entity):
return self._delitem_by_entity(
key, # type: ignore[arg-type]
)
return self._delitem_by_entity_type_id(key)
def _delitem_by_entity_type(self, entity_type: type[_TargetT & Entity]) -> None:
removed_entities = [*self._get_collection(entity_type)]
self._get_collection(entity_type).clear()
if removed_entities:
self._on_remove(*removed_entities)
def _delitem_by_entity(self, entity: _TargetT & Entity) -> None:
self.remove(entity)
def _delitem_by_entity_type_id(self, entity_type_id: MachineName) -> None:
self._delitem_by_entity_type(
self._entity_type_id_to_type_mapping[entity_type_id] # type: ignore [arg-type]
)
@override
def __iter__(self) -> Iterator[_TargetT & Entity]:
for collection in self._collections.values():
for entity in collection:
yield cast("_TargetT & Entity", entity)
@override
def __len__(self) -> int:
return sum(map(len, self._collections.values()))
@override
def __contains__(self, value: Any) -> bool:
if isinstance(value, Entity):
return self._contains_by_entity(value)
return False
def _contains_by_entity(self, other_entity: Any) -> bool:
return any(other_entity is entity for entity in self)
[docs]
@override
def add(self, *entities: _TargetT & Entity) -> None:
added_entities = [*self._unknown(*entities)]
for entity in added_entities:
self[entity.type].add(entity)
if added_entities:
self._on_add(*added_entities)
[docs]
@override
def remove(self, *entities: _TargetT & Entity) -> None:
removed_entities = [*self._known(*entities)]
for entity in removed_entities:
self[entity.type].remove(entity)
if removed_entities:
self._on_remove(*removed_entities)
[docs]
@override
def clear(self) -> None:
removed_entities = (*self,)
for collection in self._collections.values():
collection.clear()
if removed_entities:
self._on_remove(*removed_entities)
[docs]
@asynccontextmanager
async def record_added(
entities: EntityCollection[_EntityT],
) -> AsyncIterator[MultipleTypesEntityCollection[_EntityT]]:
"""
Record all entities that are added to a collection.
"""
original = [*entities]
added = await MultipleTypesEntityCollection[_EntityT].new()
yield added
added.add(*[entity for entity in entities if entity not in original])