docs for muutils v0.8.6
View Source on GitHub

muutils.misc.classes


 1from __future__ import annotations
 2
 3from typing import (
 4    Iterable,
 5    Any,
 6    Protocol,
 7    ClassVar,
 8    runtime_checkable,
 9)
10
11from muutils.misc.sequence import flatten
12
13
def is_abstract(cls: type) -> bool:
    """
    Report whether `cls` is abstract, i.e. still declares unimplemented abstract methods.

    Ordinary classes (no `__abstractmethods__` attribute at all) and concrete
    implementations of an abstract class (empty `__abstractmethods__`) are both
    reported as not abstract.
    """
    # a missing attribute is treated like an empty collection: nothing is abstract
    pending_abstract = getattr(cls, "__abstractmethods__", ())
    return len(pending_abstract) != 0
24
25
def get_all_subclasses(class_: type, include_self=False) -> set[type]:
    """
    Returns a set containing all child classes in the subclass graph of `class_`.
    I.e., includes subclasses of subclasses, etc.

    # Parameters
    - `include_self`: Whether to include `class_` itself in the returned set
    - `class_`: Superclass

    # Returns
    - `set[type]`: every class that directly or indirectly inherits from `class_`,
      plus `class_` itself when `include_self` is true

    # Development
    The subclass graph is walked iteratively with a "seen" set, so each class is
    expanded only once even when it is reachable through several parents
    (diamond inheritance), and deep hierarchies cannot hit the recursion limit.
    """
    subs: set[type] = set()
    pending: list[type] = [class_]
    while pending:
        current = pending.pop()
        # Call the descriptor explicitly: for metaclasses (subclasses of `type`,
        # and `type` itself) `current.__subclasses__` resolves to the unbound
        # `type.__subclasses__`, which raises TypeError when called with no argument.
        for sub in type.__subclasses__(current):
            if sub not in subs:
                subs.add(sub)
                pending.append(sub)
    if include_self:
        subs.add(class_)
    return subs
49
50
def isinstance_by_type_name(o: object, type_name: str) -> bool:
    """Behaves like stdlib `isinstance` except it accepts a string representation of the type rather than the type itself.
    This is a hacky function intended to circumvent the need to import a type into a module.
    It is susceptible to type name collisions.

    # Parameters
    - `o`: Object (not the type itself) whose type to interrogate
    - `type_name`: The string returned by `type_.__name__`.
      Generic types are not supported, only types that would appear in `type_.__mro__`.

    # Returns
    - `bool`: whether any class in `type(o).__mro__` has `__name__ == type_name`
    """
    # Only bare `__name__`s are compared, so same-named classes from different
    # modules are indistinguishable (the collision caveat above).
    return any(klass.__name__ == type_name for klass in type(o).__mro__)
62
63
64# dataclass magic
65# --------------------------------------------------------------------------------
66
67
@runtime_checkable
class IsDataclass(Protocol):
    """Structural type describing an instance of any (non-specific) dataclass.

    Every class processed by `@dataclasses.dataclass` carries a
    `__dataclass_fields__` class attribute, so declaring that attribute is enough
    to annotate "some dataclass" without naming a concrete class. Because the
    protocol is `@runtime_checkable`, `isinstance(x, IsDataclass)` also works at
    runtime; like every runtime protocol check, it only tests that the attribute
    is present, not its type.

    Based on https://stackoverflow.com/questions/54668000/type-hint-for-an-instance-of-a-non-specific-dataclass
    """

    # field name -> `dataclasses.Field`, installed on the class by `@dataclass`
    __dataclass_fields__: ClassVar[dict[str, Any]]
73
74
def get_hashable_eq_attrs(dc: IsDataclass) -> tuple[Any, ...]:
    """Returns a tuple of all fields used for equality comparison, including the type of the dataclass itself.
    The type is included to preserve the unequal equality behavior of instances of different dataclasses whose fields are identical.
    Essentially used to generate a hashable dataclass representation for equality comparison even if it's not frozen.

    # Parameters
    - `dc`: dataclass instance to summarize

    # Returns
    - `tuple[Any, ...]`: the value of every field declared with `compare=True`,
      in field definition order, followed by `type(dc)` as the last element

    # Notes
    Only real fields (as reported by `dataclasses.fields`) are used: `ClassVar`
    and `InitVar` pseudo-fields are skipped, matching the generated `__eq__`.
    An `InitVar` is not stored on the instance, so reading it could raise
    `AttributeError`. The result is hashable only if every compared value is.
    """
    from dataclasses import fields  # local import: this module does not import dataclasses

    return (
        *(getattr(dc, fld.name) for fld in fields(dc) if fld.compare),
        type(dc),
    )
84
85
def dataclass_set_equals(
    coll1: Iterable[IsDataclass], coll2: Iterable[IsDataclass]
) -> bool:
    """Test two collections of dataclass instances for equality as if both were sets.

    Each instance is reduced to a hashable key with `get_hashable_eq_attrs`, which
    lets unfrozen (and therefore unhashable) dataclasses be compared this way.
    As with real sets, ordering and duplicates are ignored.
    """
    left_keys = set(map(get_hashable_eq_attrs, coll1))
    right_keys = set(map(get_hashable_eq_attrs, coll2))
    return left_keys == right_keys

def is_abstract(cls: type) -> bool:
    """
    Report whether `cls` is abstract, i.e. still declares unimplemented abstract methods.

    Ordinary classes (no `__abstractmethods__` attribute at all) and concrete
    implementations of an abstract class (empty `__abstractmethods__`) are both
    reported as not abstract.
    """
    # a missing attribute is treated like an empty collection: nothing is abstract
    pending_abstract = getattr(cls, "__abstractmethods__", ())
    return len(pending_abstract) != 0

Returns whether a class is abstract.

def get_all_subclasses(class_: type, include_self=False) -> set[type]:
    """
    Returns a set containing all child classes in the subclass graph of `class_`.
    I.e., includes subclasses of subclasses, etc.

    # Parameters
    - `include_self`: Whether to include `class_` itself in the returned set
    - `class_`: Superclass

    # Returns
    - `set[type]`: every class that directly or indirectly inherits from `class_`,
      plus `class_` itself when `include_self` is true

    # Development
    The subclass graph is walked iteratively with a "seen" set, so each class is
    expanded only once even when it is reachable through several parents
    (diamond inheritance), and deep hierarchies cannot hit the recursion limit.
    """
    subs: set[type] = set()
    pending: list[type] = [class_]
    while pending:
        current = pending.pop()
        # Call the descriptor explicitly: for metaclasses (subclasses of `type`,
        # and `type` itself) `current.__subclasses__` resolves to the unbound
        # `type.__subclasses__`, which raises TypeError when called with no argument.
        for sub in type.__subclasses__(current):
            if sub not in subs:
                subs.add(sub)
                pending.append(sub)
    if include_self:
        subs.add(class_)
    return subs

Returns a set containing all child classes in the subclass graph of class_. I.e., includes subclasses of subclasses, etc.

Parameters

  • include_self: Whether to include class_ itself in the returned set
  • class_: Superclass

Development

Since most class hierarchies are small, the inefficiencies of the existing recursive implementation aren't problematic. It might be valuable to refactor with memoization if the need arises to use this function on a very large class hierarchy.

def isinstance_by_type_name(o: object, type_name: str) -> bool:
    """Behaves like stdlib `isinstance` except it accepts a string representation of the type rather than the type itself.
    This is a hacky function intended to circumvent the need to import a type into a module.
    It is susceptible to type name collisions.

    # Parameters
    - `o`: Object (not the type itself) whose type to interrogate
    - `type_name`: The string returned by `type_.__name__`.
      Generic types are not supported, only types that would appear in `type_.__mro__`.

    # Returns
    - `bool`: whether any class in `type(o).__mro__` has `__name__ == type_name`
    """
    # Only bare `__name__`s are compared, so same-named classes from different
    # modules are indistinguishable (the collision caveat above).
    return any(klass.__name__ == type_name for klass in type(o).__mro__)

Behaves like stdlib isinstance except it accepts a string representation of the type rather than the type itself. This is a hacky function intended to circumvent the need to import a type into a module. It is susceptible to type name collisions.

Parameters

  • o: Object (not the type itself) whose type to interrogate
  • type_name: The string returned by type_.__name__. Generic types are not supported, only types that would appear in type_.__mro__.

@runtime_checkable
class IsDataclass(Protocol):
    """Structural type describing an instance of any (non-specific) dataclass.

    Every class processed by `@dataclasses.dataclass` carries a
    `__dataclass_fields__` class attribute, so declaring that attribute is enough
    to annotate "some dataclass" without naming a concrete class. Because the
    protocol is `@runtime_checkable`, `isinstance(x, IsDataclass)` also works at
    runtime; like every runtime protocol check, it only tests that the attribute
    is present, not its type.

    Based on https://stackoverflow.com/questions/54668000/type-hint-for-an-instance-of-a-non-specific-dataclass
    """

    # field name -> `dataclasses.Field`, installed on the class by `@dataclass`
    __dataclass_fields__: ClassVar[dict[str, Any]]

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto[T](Protocol):
    def meth(self) -> T:
        ...
IsDataclass(*args, **kwargs)
def _no_init_or_replace_init(self, *args, **kwargs):
    """Placeholder `__init__` installed on protocol classes and their subclasses.

    Instantiating a protocol class itself raises `TypeError`. For a concrete
    subclass, the first instantiation finds the nearest real `__init__` along the
    MRO, installs it on the class in place of this placeholder, and delegates to
    it; subsequent instantiations use the installed `__init__` directly.
    """
    klass = type(self)

    if klass._is_protocol:
        raise TypeError('Protocols cannot be instantiated')

    # A custom `__init__` is already in place; recomputing it can lead to a
    # RecursionError (bpo-45121), so there is nothing left to do.
    if klass.__init__ is not _no_init_or_replace_init:
        return

    # First `__init__` along the MRO that is not this placeholder; `object.__init__`
    # is only a defensive fallback ("should not happen").
    mro_inits = (
        base.__dict__.get('__init__', _no_init_or_replace_init)
        for base in klass.__mro__
    )
    klass.__init__ = next(
        (init for init in mro_inits if init is not _no_init_or_replace_init),
        object.__init__,
    )

    klass.__init__(self, *args, **kwargs)
def get_hashable_eq_attrs(dc: IsDataclass) -> tuple[Any, ...]:
    """Returns a tuple of all fields used for equality comparison, including the type of the dataclass itself.
    The type is included to preserve the unequal equality behavior of instances of different dataclasses whose fields are identical.
    Essentially used to generate a hashable dataclass representation for equality comparison even if it's not frozen.

    # Parameters
    - `dc`: dataclass instance to summarize

    # Returns
    - `tuple[Any, ...]`: the value of every field declared with `compare=True`,
      in field definition order, followed by `type(dc)` as the last element

    # Notes
    Only real fields (as reported by `dataclasses.fields`) are used: `ClassVar`
    and `InitVar` pseudo-fields are skipped, matching the generated `__eq__`.
    An `InitVar` is not stored on the instance, so reading it could raise
    `AttributeError`. The result is hashable only if every compared value is.
    """
    from dataclasses import fields  # local import: this module does not import dataclasses

    return (
        *(getattr(dc, fld.name) for fld in fields(dc) if fld.compare),
        type(dc),
    )

Returns a tuple of all fields used for equality comparison, including the type of the dataclass itself. The type is included to preserve the unequal equality behavior of instances of different dataclasses whose fields are identical. Essentially used to generate a hashable dataclass representation for equality comparison even if it's not frozen.

def dataclass_set_equals(
    coll1: Iterable[IsDataclass], coll2: Iterable[IsDataclass]
) -> bool:
    """Test two collections of dataclass instances for equality as if both were sets.

    Each instance is reduced to a hashable key with `get_hashable_eq_attrs`, which
    lets unfrozen (and therefore unhashable) dataclasses be compared this way.
    As with real sets, ordering and duplicates are ignored.
    """
    left_keys = set(map(get_hashable_eq_attrs, coll1))
    right_keys = set(map(get_hashable_eq_attrs, coll2))
    return left_keys == right_keys

Compares 2 collections of dataclass instances as if they were sets. Duplicates are ignored in the same manner as a set. Unfrozen dataclasses can't be placed in sets since they're not hashable. Collections of them may be compared using this function.