# -*- coding: utf-8 -*-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
# Elements of BaseObject reuse code developed in scikit-learn. These elements
# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
"""Base class template for objects and fittable objects.
templates in this module:
BaseObject - object with parameters and tags
BaseEstimator - BaseObject that can be fitted
Interface specifications below.
---
class name: BaseObject
Parameter inspection and setter methods
inspect parameter values - get_params()
setting parameter values - set_params(**params)
list of parameter names - get_param_names()
dict of parameter defaults - get_param_defaults()
Tag inspection and setter methods
inspect tags (all) - get_tags()
inspect tags (one tag) - get_tag(tag_name: str, tag_value_default=None)
inspect tags (class method) - get_class_tags()
inspect tags (one tag, class) - get_class_tag(tag_name:str, tag_value_default=None)
setting dynamic tags - set_tag(**tag_dict: dict)
set/clone dynamic tags - clone_tags(estimator, tag_names=None)
Blueprinting: resetting and cloning, post-init state with same hyper-parameters
reset object to post-init - reset()
clone object (copy&reset) - clone()
Testing with default parameters methods
getting default parameters (all sets) - get_test_params()
get one test instance with default parameters - create_test_instance()
get list of all test instances plus name list - create_test_instances_and_names()
---
class name: BaseEstimator
Provides all interface points of BaseObject, plus:
Parameter inspection:
fitted parameter inspection - get_fitted_params()
State:
fitted model/strategy - by convention, any attributes ending in "_"
fitted state flag - is_fitted (property)
fitted state check - check_is_fitted (raises error if not is_fitted)
"""
import inspect
import re
import warnings
from collections import defaultdict
from copy import deepcopy
from typing import List
from skbase._exceptions import NotFittedError
from skbase.base._pretty_printing._object_html_repr import _object_html_repr
from skbase.base._tagmanager import _FlagManager
__author__: List[str] = ["fkiraly", "mloning", "RNKuhns", "tpvasconcelos"]
__all__: List[str] = ["BaseEstimator", "BaseObject"]
class BaseObject(_FlagManager):
"""Base class for parametric objects with sktime style tag interface.
Extends scikit-learn's BaseEstimator to include sktime style interface for tags.
"""
_config = {
"display": "diagram",
"print_changed_only": True,
"check_clone": False, # whether to execute validity checks in clone
"clone_config": True, # clone config values (True) or use defaults (False)
}
def __init__(self):
"""Construct BaseObject."""
self._init_flags(flag_attr_name="_tags")
self._init_flags(flag_attr_name="_config")
super(BaseObject, self).__init__()
def __eq__(self, other):
"""Equality dunder. Checks equal class and parameters.
Returns True iff result of ``get_params(deep=False)``
results in equal parameter sets.
Nested BaseObject descendants from ``get_params`` are compared via
``__eq__`` as well.
"""
from skbase.utils.deep_equals import deep_equals
if not isinstance(other, BaseObject):
return False
self_params = self.get_params(deep=False)
other_params = other.get_params(deep=False)
return deep_equals(self_params, other_params)
def reset(self):
"""Reset the object to a clean post-init state.
Results in setting ``self`` to the state it had directly
after the constructor call, with the same hyper-parameters.
Config values set by ``set_config`` are also retained.
A ``reset`` call deletes any object attributes, except:
- hyper-parameters = arguments of ``__init__`` written to ``self``,
e.g., ``self.paramname`` where ``paramname`` is an argument of ``__init__``
- object attributes containing double-underscores, i.e., the string "__".
For instance, an attribute named "__myattr" is retained.
- config attributes, configs are retained without change.
That is, results of ``get_config`` before and after ``reset`` are equal.
Class and object methods, and class attributes are also unaffected.
Equivalent to ``clone``, with the exception that ``reset``
mutates ``self`` instead of returning a new object.
After a ``self.reset()`` call,
``self`` is equal in value and state, to the object obtained after
a constructor call``type(self)(**self.get_params(deep=False))``.
Returns
-------
self
Instance of class reset to a clean post-init state but retaining
the current hyper-parameter values.
"""
# retrieve parameters to copy them later
params = self.get_params(deep=False)
config = self.get_config()
# delete all object attributes in self
attrs = [attr for attr in dir(self) if "__" not in attr]
cls_attrs = list(dir(type(self)))
self_attrs = set(attrs).difference(cls_attrs)
for attr in self_attrs:
delattr(self, attr)
# run init with a copy of parameters self had at the start
self.__init__(**params)
self.set_config(**config)
return self
def clone(self):
"""Obtain a clone of the object with same hyper-parameters and config.
A clone is a different object without shared references, in post-init state.
This function is equivalent to returning ``sklearn.clone`` of ``self``.
Equivalent to constructing a new instance of ``type(self)``, with
parameters of ``self``, that is,
``type(self)(**self.get_params(deep=False))``.
If configs were set on ``self``, the clone will also have the same configs
as the original,
equivalent to calling ``cloned_self.set_config(**self.get_config())``.
Also equivalent in value to a call of ``self.reset``,
with the exception that ``clone`` returns a new object,
instead of mutating ``self`` like ``reset``.
Raises
------
RuntimeError if the clone is non-conforming, due to faulty ``__init__``.
"""
self_clone = _clone(self)
if self.get_config()["check_clone"]:
_check_clone(original=self, clone=self_clone)
return self_clone
@classmethod
def _get_init_signature(cls):
"""Get class init signature.
Useful in parameter inspection.
Returns
-------
List
The inspected parameter objects (including defaults).
Raises
------
RuntimeError if ``cls`` has varargs in ``__init__``.
"""
# fetch the constructor or the original constructor before
# deprecation wrapping if any
init = getattr(cls.__init__, "deprecated_original", cls.__init__)
if init is object.__init__:
# No explicit constructor to introspect
return []
# introspect the constructor arguments to find the model parameters
# to represent
init_signature = inspect.signature(init)
# Consider the constructor parameters excluding 'self'
parameters = [
p
for p in init_signature.parameters.values()
if p.name != "self" and p.kind != p.VAR_KEYWORD
]
for p in parameters:
if p.kind == p.VAR_POSITIONAL:
raise RuntimeError(
"scikit-base compatible classes should always "
"specify their parameters in the signature"
" of their __init__ (no varargs)."
" %s with constructor %s doesn't "
" follow this convention." % (cls, init_signature)
)
return parameters
@classmethod
def get_param_names(cls, sort=True):
"""Get object's parameter names.
Parameters
----------
sort : bool, default=True
Whether to return the parameter names sorted in alphabetical order (True),
or in the order they appear in the class ``__init__`` (False).
Returns
-------
param_names: list[str]
List of parameter names of ``cls``.
If ``sort=False``, in same order as they appear in the class ``__init__``.
If ``sort=True``, alphabetically ordered.
"""
if sort is None:
sort = True
parameters = cls._get_init_signature()
param_names = [p.name for p in parameters]
if sort:
param_names = sorted(param_names)
return param_names
@classmethod
def get_param_defaults(cls):
"""Get object's parameter defaults.
Returns
-------
default_dict: dict[str, Any]
Keys are all parameters of ``cls`` that have
a default defined in ``__init__``.
Values are the defaults, as defined in ``__init__``.
"""
parameters = cls._get_init_signature()
default_dict = {
x.name: x.default for x in parameters if x.default != inspect._empty
}
return default_dict
def get_params(self, deep=True):
"""Get a dict of parameters values for this object.
Parameters
----------
deep : bool, default=True
Whether to return parameters of components.
* If ``True``, will return a ``dict`` of
parameter name : value for this object,
including parameters of components (= ``BaseObject``-valued parameters).
* If ``False``, will return a ``dict``
of parameter name : value for this object,
but not include parameters of components.
Returns
-------
params : dict with str-valued keys
Dictionary of parameters, paramname : paramvalue
keys-value pairs include:
* always: all parameters of this object, as via ``get_param_names``
values are parameter value for that key, of this object
values are always identical to values passed at construction
* if ``deep=True``, also contains keys/value pairs of component parameters
parameters of components are indexed as ``[componentname]__[paramname]``
all parameters of ``componentname`` appear as ``paramname`` with its value
* if ``deep=True``, also contains arbitrary levels of component recursion,
e.g., ``[componentname]__[componentcomponentname]__[paramname]``, etc
"""
params = {key: getattr(self, key) for key in self.get_param_names()}
if deep:
deep_params = {}
for key, value in params.items():
if hasattr(value, "get_params"):
deep_items = value.get_params().items()
deep_params.update({f"{key}__{k}": val for k, val in deep_items})
params.update(deep_params)
return params
def set_params(self, **params):
"""Set the parameters of this object.
The method works on simple skbase objects as well as on composite objects.
Parameter key strings ``<component>__<parameter>`` can be used for composites,
i.e., objects that contain other objects, to access ``<parameter>`` in
the component ``<component>``.
The string ``<parameter>``, without ``<component>__``, can also be used if
this makes the reference unambiguous, e.g., there are no two parameters of
components with the name ``<parameter>``.
Parameters
----------
**params : dict
BaseObject parameters, keys must be ``<component>__<parameter>`` strings.
``__`` suffixes can alias full strings, if unique among get_params keys.
Returns
-------
self : reference to self (after parameters have been set)
"""
if not params:
# Simple optimization to gain speed (inspect is slow)
return self
valid_params = self.get_params(deep=True)
unmatched_keys = []
nested_params = defaultdict(dict) # grouped by prefix
for full_key, value in params.items():
# split full_key by first occurrence of __, if contains __
# "key_without_dblunderscore" -> "key_without_dbl_underscore", None, None
# "key__with__dblunderscore" -> "key", "__", "with__dblunderscore"
key, delim, sub_key = full_key.partition("__")
# if key not recognized, remember for suffix matching
if key not in valid_params:
unmatched_keys += [key]
# if full_key contained __, collect suffix for component set_params
elif delim:
nested_params[key][sub_key] = value
# if key is found and did not contain __, set self.key to the value
else:
setattr(self, key, value)
valid_params[key] = value
# all matched params have now been set
# reset object to clean post-init state with those params
self.reset()
# recurse in components
for key, sub_params in nested_params.items():
valid_params[key].set_params(**sub_params)
# for unmatched keys, resolve by aliasing via available __ suffixes, recurse
if len(unmatched_keys) > 0:
valid_params = self.get_params(deep=True)
unmatched_params = {key: params[key] for key in unmatched_keys}
# aliasing, syntactic sugar to access uniquely named params more easily
aliased_params = self._alias_params(unmatched_params, valid_params)
# if none of the parameter names change through aliasing, raise error
if set(aliased_params) == set(unmatched_params):
raise ValueError(
f"Invalid parameter keys provided to set_params of object {self}. "
"Check the list of available parameters "
"with `object.get_params().keys()`. "
f"Invalid keys provided: {unmatched_keys}"
)
# recurse: repeat matching and aliasing until no further matches found
# termination condition is above, "no change in keys via aliasing"
self.set_params(**aliased_params)
return self
def _alias_params(self, d, valid_params):
"""Replace shorthands in d by full keys from valid_params.
Parameters
----------
d: dict with str keys
valid_params: dict with str keys
Result
------
alias_dict: dict with str keys, all keys in valid_params
values are as in d, with keys replaced by following rule:
* If key is a ``__`` suffix of exactly one key in ``valid_params``,
it is replaced by that key. Otherwise an exception is raised.
* A ``__``-suffix of a ``str`` is any ``str`` obtained as suffix
from partition by the string ``"__"``.
Else, i.e., if key is in valid_params or not a ``__``-suffix,
the key is replaced by itself, i.e., left unchanged.
Raises
------
ValueError if at least one key of d is neither contained in ``valid_params``,
nor is it a ``__``-suffix of exactly one key in ``valid_params``
"""
def _is_suffix(x, y):
"""Return whether x is a strict __ suffix of y."""
return y.endswith(x) and y.endswith("__" + x)
def _get_alias(x, d):
"""Return alias of x in d."""
# if key is in valid_params, key is replaced by key (itself)
if any(x == y for y in d.keys()):
return x
suff_list = [y for y in d.keys() if _is_suffix(x, y)]
# if key is a __ suffix of exactly one key in valid_params,
# it is replaced by that key
ns = len(suff_list)
if ns > 1:
raise ValueError(
f"suffix {x} does not uniquely determine parameter key, of "
f"{type(self).__name__} instance"
f"the following parameter keys have the same suffix: {suff_list}"
)
if ns == 0:
return x
# if ns == 1
return suff_list[0]
alias_dict = {_get_alias(x, valid_params): d[x] for x in d.keys()}
return alias_dict
@classmethod
def get_class_tags(cls):
"""Get class tags from class, with tag level inheritance from parent classes.
Every ``scikit-base`` compatible object has a dictionary of tags.
Tags may be used to store metadata about the object,
or to control behaviour of the object.
Tags are key-value pairs specific to an instance ``self``,
they are static flags that are not changed after construction
of the object.
The ``get_class_tags`` method is a class method,
and retrieves the value of a tag
taking into account only class-level tag values and overrides.
It returns a dictionary with keys being keys of any attribute of ``_tags``
set in the class or any of its parent classes.
Values are the corresponding tag values, with overrides in the following
order of descending priority:
1. Tags set in the ``_tags`` attribute of the class.
2. Tags set in the ``_tags`` attribute of parent classes,
in order of inheritance.
Instances can override these tags depending on hyper-parameters.
To retrieve tags with potential instance overrides, use
the ``get_tags`` method instead.
Does not take into account dynamic tag overrides on instances,
set via ``set_tags`` or ``clone_tags``,
that are defined on instances.
For including overrides from dynamic tags, use ``get_tags``.
Returns
-------
collected_tags : dict
Dictionary of tag name : tag value pairs. Collected from ``_tags``
class attribute via nested inheritance. NOT overridden by dynamic
tags set by ``set_tags`` or ``clone_tags``.
"""
return cls._get_class_flags(flag_attr_name="_tags")
@classmethod
def get_class_tag(cls, tag_name, tag_value_default=None):
"""Get class tag value from class, with tag level inheritance from parents.
Every ``scikit-base`` compatible object has a dictionary of tags.
Tags may be used to store metadata about the object,
or to control behaviour of the object.
Tags are key-value pairs specific to an instance ``self``,
they are static flags that are not changed after construction
of the object.
The ``get_class_tag`` method is a class method, and retrieves the value of a tag
taking into account only class-level tag values and overrides.
It returns the value of the tag with name ``tag_name`` from the object,
taking into account tag overrides, in the following
order of descending priority:
1. Tags set in the ``_tags`` attribute of the class.
2. Tags set in the ``_tags`` attribute of parent classes,
in order of inheritance.
Does not take into account dynamic tag overrides on instances,
set via ``set_tags`` or ``clone_tags``,
that are defined on instances.
To retrieve tag values with potential instance overrides, use
the ``get_tag`` method instead.
Parameters
----------
tag_name : str
Name of tag value.
tag_value_default : any type
Default/fallback value if tag is not found.
Returns
-------
tag_value :
Value of the ``tag_name`` tag in ``self``.
If not found, returns ``tag_value_default``.
"""
return cls._get_class_flag(
flag_name=tag_name,
flag_value_default=tag_value_default,
flag_attr_name="_tags",
)
def get_tags(self):
"""Get tags from instance, with tag level inheritance and overrides.
Every ``scikit-base`` compatible object has a dictionary of tags.
Tags may be used to store metadata about the object,
or to control behaviour of the object.
Tags are key-value pairs specific to an instance ``self``,
they are static flags that are not changed after construction
of the object.
The ``get_tags`` method returns a dictionary of tags,
with keys being keys of any attribute of ``_tags``
set in the class or any of its parent classes, or tags set via ``set_tags``
or ``clone_tags``.
Values are the corresponding tag values, with overrides in the following
order of descending priority:
1. Tags set via ``set_tags`` or ``clone_tags`` on the instance,
at construction of the instance.
2. Tags set in the ``_tags`` attribute of the class.
3. Tags set in the ``_tags`` attribute of parent classes,
in order of inheritance.
Returns
-------
collected_tags : dict
Dictionary of tag name : tag value pairs. Collected from ``_tags``
class attribute via nested inheritance and then any overrides
and new tags from ``_tags_dynamic`` object attribute.
"""
return self._get_flags(flag_attr_name="_tags")
def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
"""Get tag value from instance, with tag level inheritance and overrides.
Every ``scikit-base`` compatible object has a dictionary of tags.
Tags may be used to store metadata about the object,
or to control behaviour of the object.
Tags are key-value pairs specific to an instance ``self``,
they are static flags that are not changed after construction
of the object.
The ``get_tag`` method retrieves the value of a single tag
with name ``tag_name`` from the instance,
taking into account tag overrides, in the following
order of descending priority:
1. Tags set via ``set_tags`` or ``clone_tags`` on the instance,
at construction of the instance.
2. Tags set in the ``_tags`` attribute of the class.
3. Tags set in the ``_tags`` attribute of parent classes,
in order of inheritance.
Parameters
----------
tag_name : str
Name of tag to be retrieved
tag_value_default : any type, optional; default=None
Default/fallback value if tag is not found
raise_error : bool
whether a ``ValueError`` is raised when the tag is not found
Returns
-------
tag_value : Any
Value of the ``tag_name`` tag in ``self``.
If not found, raises an error if
``raise_error`` is True, otherwise it returns ``tag_value_default``.
Raises
------
ValueError, if ``raise_error`` is ``True``.
The ``ValueError`` is then raised if ``tag_name`` is
not in ``self.get_tags().keys()``.
"""
return self._get_flag(
flag_name=tag_name,
flag_value_default=tag_value_default,
raise_error=raise_error,
flag_attr_name="_tags",
)
def set_tags(self, **tag_dict):
"""Set instance level tag overrides to given values.
Every ``scikit-base`` compatible object has a dictionary of tags.
Tags may be used to store metadata about the object,
or to control behaviour of the object.
Tags are key-value pairs specific to an instance ``self``,
they are static flags that are not changed after construction
of the object.
``set_tags`` sets dynamic tag overrides
to the values as specified in ``tag_dict``, with keys being the tag name,
and dict values being the value to set the tag to.
The ``set_tags`` method
should be called only in the ``__init__`` method of an object,
during construction, or directly after construction via ``__init__``.
Current tag values can be inspected by ``get_tags`` or ``get_tag``.
Parameters
----------
**tag_dict : dict
Dictionary of tag name: tag value pairs.
Returns
-------
Self
Reference to self.
"""
self._set_flags(flag_attr_name="_tags", **tag_dict)
return self
def clone_tags(self, estimator, tag_names=None):
"""Clone tags from another object as dynamic override.
Every ``scikit-base`` compatible object has a dictionary of tags.
Tags may be used to store metadata about the object,
or to control behaviour of the object.
Tags are key-value pairs specific to an instance ``self``,
they are static flags that are not changed after construction
of the object.
``clone_tags`` sets dynamic tag overrides
from another object, ``estimator``.
The ``clone_tags`` method
should be called only in the ``__init__`` method of an object,
during construction, or directly after construction via ``__init__``.
The dynamic tags are set to the values of the tags in ``estimator``,
with the names specified in ``tag_names``.
The default of ``tag_names`` writes all tags from ``estimator`` to ``self``.
Current tag values can be inspected by ``get_tags`` or ``get_tag``.
Parameters
----------
estimator : An instance of :class:BaseObject or derived class
tag_names : str or list of str, default = None
Names of tags to clone.
The default (``None``) clones all tags from ``estimator``.
Returns
-------
self :
Reference to ``self``.
"""
self._clone_flags(
estimator=estimator, flag_names=tag_names, flag_attr_name="_tags"
)
return self
def get_config(self):
"""Get config flags for self.
Configs are key-value pairs of ``self``,
typically used as transient flags for controlling behaviour.
``get_config`` returns dynamic configs, which override the default configs.
Default configs are set in the class attribute ``_config`` of
the class or its parent classes,
and are overridden by dynamic configs set via ``set_config``.
Configs are retained under ``clone`` or ``reset`` calls.
Returns
-------
config_dict : dict
Dictionary of config name : config value pairs. Collected from _config
class attribute via nested inheritance and then any overrides
and new tags from _onfig_dynamic object attribute.
"""
return self._get_flags(flag_attr_name="_config")
def set_config(self, **config_dict):
"""Set config flags to given values.
Configs are key-value pairs of ``self``,
typically used as transient flags for controlling behaviour.
``set_config`` sets dynamic configs, which override the default configs.
Default configs are set in the class attribute ``_config`` of
the class or its parent classes,
and are overridden by dynamic configs set via ``set_config``.
Configs are retained under ``clone`` or ``reset`` calls.
Parameters
----------
config_dict : dict
Dictionary of config name : config value pairs.
Returns
-------
self : reference to self.
Notes
-----
Changes object state, copies configs in config_dict to self._config_dynamic.
"""
self._set_flags(flag_attr_name="_config", **config_dict)
return self
@classmethod
def get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the skbase object.
``get_test_params`` is a unified interface point to store
parameter settings for testing purposes. This function is also
used in ``create_test_instance`` and ``create_test_instances_and_names``
to construct test instances.
``get_test_params`` should return a single ``dict``, or a ``list`` of ``dict``.
Each ``dict`` is a parameter configuration for testing,
and can be used to construct an "interesting" test instance.
A call to ``cls(**params)`` should
be valid for all dictionaries ``params`` in the return of ``get_test_params``.
The ``get_test_params`` need not return fixed lists of dictionaries,
it can also return dynamic or stochastic parameter settings.
Parameters
----------
parameter_set : str, default="default"
Name of the set of test parameters to return, for use in tests. If no
special parameters are defined for a value, will return `"default"` set.
Returns
-------
params : dict or list of dict, default = {}
Parameters to create testing instances of the class
Each dict are parameters to construct an "interesting" test instance, i.e.,
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
`create_test_instance` uses the first (or only) dictionary in `params`
"""
params_with_defaults = set(cls.get_param_defaults().keys())
all_params = set(cls.get_param_names(sort=False))
params_without_defaults = all_params - params_with_defaults
# if non-default parameters are required, but none have been found, raise error
if len(params_without_defaults) > 0:
raise ValueError(
f"skbase object {cls} has parameters without default values, "
f"but these are not set in get_test_params. "
f"Please set them in get_test_params, or provide default values. "
f"Also see the respective extension template, if applicable."
)
# construct with parameter configuration for testing, otherwise construct with
# default parameters (empty dict)
params = {}
return params
@classmethod
def create_test_instance(cls, parameter_set="default"):
"""Construct an instance of the class, using first test parameter set.
Parameters
----------
parameter_set : str, default="default"
Name of the set of test parameters to return, for use in tests. If no
special parameters are defined for a value, will return `"default"` set.
Returns
-------
instance : instance of the class with default parameters
"""
if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args:
params = cls.get_test_params(parameter_set=parameter_set)
else:
params = cls.get_test_params()
if isinstance(params, list) and isinstance(params[0], dict):
params = params[0]
elif isinstance(params, dict):
pass
else:
raise TypeError(
"get_test_params should either return a dict or list of dict."
)
return cls._safe_init_test_params(params)
@classmethod
def create_test_instances_and_names(cls, parameter_set="default"):
"""Create list of all test instances and a list of names for them.
Parameters
----------
parameter_set : str, default="default"
Name of the set of test parameters to return, for use in tests. If no
special parameters are defined for a value, will return `"default"` set.
Returns
-------
objs : list of instances of cls
i-th instance is ``cls(**cls.get_test_params()[i])``
names : list of str, same length as objs
i-th element is name of i-th instance of obj in tests.
The naming convention is ``{cls.__name__}-{i}`` if more than one instance,
otherwise ``{cls.__name__}``
"""
if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args:
param_list = cls.get_test_params(parameter_set=parameter_set)
else:
param_list = cls.get_test_params()
objs = []
if not isinstance(param_list, (dict, list)):
raise RuntimeError(
f"Error in {cls.__name__}.get_test_params, "
"return must be param dict for class, or list thereof"
)
if isinstance(param_list, dict):
param_list = [param_list]
for params in param_list:
if not isinstance(params, dict):
raise RuntimeError(
f"Error in {cls.__name__}.get_test_params, "
"return must be param dict for class, or list thereof"
)
objs += [cls._safe_init_test_params(params)]
num_instances = len(param_list)
if num_instances > 1:
names = [cls.__name__ + "-" + str(i) for i in range(num_instances)]
else:
names = [cls.__name__]
return objs, names
@classmethod
def _safe_init_test_params(cls, params):
"""Safe init of cls with params for testing.
Will raise informative error message if params are not valid.
"""
try:
return cls(**params)
except Exception as e:
raise type(e)(
f"Error in {cls.__name__}.get_test_params, "
"return must be valid param dict for class, or list thereof, "
"but attempted construction raised a exception. "
f"Problematic parameter set: {params}. Exception raised: {e}"
) from e
@classmethod
def _has_implementation_of(cls, method):
"""Check if method has a concrete implementation in this class.
This assumes that having an implementation is equivalent to
one or more overrides of `method` in the method resolution order.
Parameters
----------
method : str, name of method to check implementation of
Returns
-------
bool, whether method has implementation in cls
True if cls.method has been overridden at least once in
the inheritance tree (according to method resolution order)
"""
# walk through method resolution order and inspect methods
# of classes and direct parents, "adjacent" classes in mro
mro = inspect.getmro(cls)
# collect all methods that are not none
methods = [getattr(c, method, None) for c in mro]
methods = [m for m in methods if m is not None]
for i in range(len(methods) - 1):
# the method has been overridden once iff
# at least two of the methods collected are not equal
# equivalently: some two adjacent methods are not equal
overridden = methods[i] != methods[i + 1]
if overridden:
return True
return False
def is_composite(self):
"""Check if the object is composed of other BaseObjects.
A composite object is an object which contains objects, as parameters.
Called on an instance, since this may differ by instance.
Returns
-------
composite: bool
Whether an object has any parameters whose values
are ``BaseObject`` descendant instances.
"""
# walk through method resolution order and inspect methods
# of classes and direct parents, "adjacent" classes in mro
params = self.get_params(deep=False)
composite = any(isinstance(x, BaseObject) for x in params.values())
return composite
def _components(self, base_class=None):
"""Return references to all state changing BaseObject type attributes.
This *excludes* the blue-print-like components passed in the ``__init__``.
Caution: this method returns *references* and not *copies*.
Writing to the reference will change the respective attribute of self.
Parameters
----------
base_class : class, optional, default=None, must be subclass of BaseObject
if not ``None``, sub-sets return dict to only descendants of ``base_class``
Returns
-------
dict with key = attribute name, value = reference to that BaseObject attribute
dict contains all attributes of self that inherit from BaseObjects, and:
whose names do not contain the string "__", e.g., hidden attributes
are not class attributes, and are not hyper-parameters (__init__ args)
"""
if base_class is None:
base_class = BaseObject
if base_class is not None and not inspect.isclass(base_class):
raise TypeError(f"base_class must be a class, but found {type(base_class)}")
if base_class is not None and not issubclass(base_class, BaseObject):
raise TypeError("base_class must be a subclass of BaseObject")
# retrieve parameter names to exclude them later
param_names = self.get_params(deep=False).keys()
# retrieve all attributes that are BaseObject descendants
attrs = [attr for attr in dir(self) if "__" not in attr]
cls_attrs = list(dir(type(self)))
self_attrs = set(attrs).difference(cls_attrs).difference(param_names)
comp_dict = {x: getattr(self, x) for x in self_attrs}
comp_dict = {x: y for (x, y) in comp_dict.items() if isinstance(y, base_class)}
return comp_dict
def __repr__(self, n_char_max: int = 700):
"""Represent class as string.
This follows the scikit-learn implementation for the string representation
of parameterized objects.
Parameters
----------
n_char_max : int
Maximum (approximate) number of non-blank characters to render. This
can be useful in testing.
"""
from skbase.base._pretty_printing._pprint import _BaseObjectPrettyPrinter
n_max_elements_to_show = 30 # number of elements to show in sequences
# use ellipsis for sequences with a lot of elements
pp = _BaseObjectPrettyPrinter(
compact=True,
indent=1,
indent_at_name=True,
n_max_elements_to_show=n_max_elements_to_show,
changed_only=self.get_config()["print_changed_only"],
)
repr_ = pp.pformat(self)
# Use bruteforce ellipsis when there are a lot of non-blank characters
n_nonblank = len("".join(repr_.split()))
if n_nonblank > n_char_max:
lim = n_char_max // 2 # apprx number of chars to keep on both ends
regex = r"^(\s*\S){%d}" % lim
# The regex '^(\s*\S){%d}' matches from the start of the string
# until the nth non-blank character:
# - ^ matches the start of string
# - (pattern){n} matches n repetitions of pattern
# - \s*\S matches a non-blank char following zero or more blanks
left_match = re.match(regex, repr_)
right_match = re.match(regex, repr_[::-1])
left_lim = left_match.end() if left_match is not None else 0
right_lim = right_match.end() if right_match is not None else 0
if "\n" in repr_[left_lim:-right_lim]:
# The left side and right side aren't on the same line.
# To avoid weird cuts, e.g.:
# categoric...ore',
# we need to start the right side with an appropriate newline
# character so that it renders properly as:
# categoric...
# handle_unknown='ignore',
# so we add [^\n]*\n which matches until the next \n
regex += r"[^\n]*\n"
right_match = re.match(regex, repr_[::-1])
right_lim = right_match.end() if right_match is not None else 0
ellipsis = "..."
if left_lim + len(ellipsis) < len(repr_) - right_lim:
# Only add ellipsis if it results in a shorter repr
repr_ = repr_[:left_lim] + "..." + repr_[-right_lim:]
return repr_
@property
def _repr_html_(self):
"""HTML representation of BaseObject.
This is redundant with the logic of `_repr_mimebundle_`. The latter
should be favorted in the long term, `_repr_html_` is only
implemented for consumers who do not interpret `_repr_mimbundle_`.
"""
if self.get_config()["display"] != "diagram":
raise AttributeError(
"_repr_html_ is only defined when the "
"`display` configuration option is set to 'diagram'."
)
return self._repr_html_inner
def _repr_html_inner(self):
"""Return HTML representation of class.
This function is returned by the @property `_repr_html_` to make
`hasattr(BaseObject, "_repr_html_") return `True` or `False` depending
on `self.get_config()["display"]`.
"""
return _object_html_repr(self)
def _repr_mimebundle_(self, **kwargs):
"""Mime bundle used by jupyter kernels to display instances of BaseObject."""
output = {"text/plain": repr(self)}
if self.get_config()["display"] == "diagram":
output["text/html"] = _object_html_repr(self)
return output
def set_random_state(self, random_state=None, deep=True, self_policy="copy"):
"""Set random_state pseudo-random seed parameters for self.
Finds ``random_state`` named parameters via ``self.get_params``,
and sets them to integers derived from ``random_state`` via ``set_params``.
These integers are sampled from chain hashing via ``sample_dependent_seed``,
and guarantee pseudo-random independence of seeded random generators.
Applies to ``random_state`` parameters in ``self``, depending on
``self_policy``, and remaining component objects
if and only if ``deep=True``.
Note: calls ``set_params`` even if ``self`` does not have a ``random_state``,
or none of the components have a ``random_state`` parameter.
Therefore, ``set_random_state`` will reset any ``scikit-base`` object,
even those without a ``random_state`` parameter.
Parameters
----------
random_state : int, RandomState instance or None, default=None
Pseudo-random number generator to control the generation of the random
integers. Pass int for reproducible output across multiple function calls.
deep : bool, default=True
Whether to set the random state in skbase object valued parameters, i.e.,
component estimators.
* If False, will set only ``self``'s ``random_state`` parameter, if exists.
* If True, will set ``random_state`` parameters in component objects
as well.
self_policy : str, one of {"copy", "keep", "new"}, default="copy"
* "copy" : ``self.random_state`` is set to input ``random_state``
* "keep" : ``self.random_state`` is kept as is
* "new" : ``self.random_state`` is set to a new random state,
derived from input ``random_state``, and in general different from it
Returns
-------
self : reference to self
"""
from skbase.utils.random_state import set_random_state
return set_random_state(
self,
random_state=random_state,
deep=deep,
root_policy=self_policy,
)
class TagAliaserMixin:
"""Mixin class for tag aliasing and deprecation of old tags.
To deprecate tags, add the TagAliaserMixin to BaseObject or BaseEstimator.
alias_dict contains the deprecated tags, and supports removal and renaming.
For removal, add an entry "old_tag_name": ""
For renaming, add an entry "old_tag_name": "new_tag_name"
deprecate_dict contains the version number of renaming or removal.
the keys in deprecate_dict should be the same as in alias_dict.
values in deprecate_dict should be strings, the version of removal/renaming.
The class will ensure that new tags alias old tags and vice versa, during
the deprecation period. Informative warnings will be raised whenever the
deprecated tags are being accessed.
When removing tags, ensure to remove the removed tags from this class.
If no tags are deprecated anymore (e.g., all deprecated tags are removed/renamed),
ensure toremove this class as a parent of BaseObject or BaseEstimator.
"""
# dictionary of aliases
# key = old tag; value = new tag, aliased by old tag
# override this in a child class
alias_dict = {"old_tag": "new_tag", "tag_to_remove": ""}
# dictionary of removal version
# key = old tag; value = version in which tag will be removed, as string
deprecate_dict = {"old_tag": "0.12.0", "tag_to_remove": "99.99.99"}
def __init__(self):
"""Construct TagAliaserMixin."""
super(TagAliaserMixin, self).__init__()
@classmethod
def get_class_tags(cls):
"""Get class tags from class, with tag level inheritance from parent classes.
Every ``scikit-base`` compatible object has a dictionary of tags.
Tags may be used to store metadata about the object,
or to control behaviour of the object.
Tags are key-value pairs specific to an instance ``self``,
they are static flags that are not changed after construction
of the object.
The ``get_class_tags`` method is a class method,
and retrieves the value of a tag
taking into account only class-level tag values and overrides.
It returns a dictionary with keys being keys of any attribute of ``_tags``
set in the class or any of its parent classes.
Values are the corresponding tag values, with overrides in the following
order of descending priority:
1. Tags set in the ``_tags`` attribute of the class.
2. Tags set in the ``_tags`` attribute of parent classes,
in order of inheritance.
Instances can override these tags depending on hyper-parameters.
To retrieve tags with potential instance overrides, use
the ``get_tags`` method instead.
Does not take into account dynamic tag overrides on instances,
set via ``set_tags`` or ``clone_tags``,
that are defined on instances.
For including overrides from dynamic tags, use ``get_tags``.
collected_tags : dict
Dictionary of tag name : tag value pairs. Collected from ``_tags``
class attribute via nested inheritance. NOT overridden by dynamic
tags set by ``set_tags`` or ``clone_tags``.
"""
collected_tags = super(TagAliaserMixin, cls).get_class_tags()
collected_tags = cls._complete_dict(collected_tags)
return collected_tags
@classmethod
def get_class_tag(cls, tag_name, tag_value_default=None):
"""Get class tag value from class, with tag level inheritance from parents.
Every ``scikit-base`` compatible object has a dictionary of tags,
which are used to store metadata about the object.
The ``get_class_tag`` method is a class method,
and retrieves the value of a tag
taking into account only class-level tag values and overrides.
It returns the value of the tag with name ``tag_name`` from the object,
taking into account tag overrides, in the following
order of descending priority:
1. Tags set in the ``_tags`` attribute of the class.
2. Tags set in the ``_tags`` attribute of parent classes,
in order of inheritance.
Does not take into account dynamic tag overrides on instances,
set via ``set_tags`` or ``clone_tags``,
that are defined on instances.
To retrieve tag values with potential instance overrides, use
the ``get_tag`` method instead.
Parameters
----------
tag_name : str
Name of tag value.
tag_value_default : any type
Default/fallback value if tag is not found.
Returns
-------
tag_value :
Value of the ``tag_name`` tag in ``self``.
If not found, returns ``tag_value_default``.
"""
cls._deprecate_tag_warn([tag_name])
return super(TagAliaserMixin, cls).get_class_tag(
tag_name=tag_name, tag_value_default=tag_value_default
)
def get_tags(self):
"""Get tags from instance, with tag level inheritance and overrides.
Every ``scikit-base`` compatible object has a dictionary of tags.
Tags may be used to store metadata about the object,
or to control behaviour of the object.
Tags are key-value pairs specific to an instance ``self``,
they are static flags that are not changed after construction
of the object.
The ``get_tags`` method returns a dictionary of tags,
with keys being keys of any attribute of ``_tags``
set in the class or any of its parent classes, or tags set via ``set_tags``
or ``clone_tags``.
Values are the corresponding tag values, with overrides in the following
order of descending priority:
1. Tags set via ``set_tags`` or ``clone_tags`` on the instance,
at construction of the instance.
2. Tags set in the ``_tags`` attribute of the class.
3. Tags set in the ``_tags`` attribute of parent classes,
in order of inheritance.
Returns
-------
collected_tags : dict
Dictionary of tag name : tag value pairs. Collected from ``_tags``
class attribute via nested inheritance and then any overrides
and new tags from ``_tags_dynamic`` object attribute.
"""
collected_tags = super(TagAliaserMixin, self).get_tags()
collected_tags = self._complete_dict(collected_tags)
return collected_tags
def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
"""Get tag value from instance, with tag level inheritance and overrides.
Every ``scikit-base`` compatible object has a dictionary of tags.
Tags may be used to store metadata about the object,
or to control behaviour of the object.
Tags are key-value pairs specific to an instance ``self``,
they are static flags that are not changed after construction
of the object.
The ``get_tag`` method retrieves the value of a single tag
with name ``tag_name`` from the instance,
taking into account tag overrides, in the following
order of descending priority:
1. Tags set via ``set_tags`` or ``clone_tags`` on the instance,
at construction of the instance.
2. Tags set in the ``_tags`` attribute of the class.
3. Tags set in the ``_tags`` attribute of parent classes,
in order of inheritance.
Parameters
----------
tag_name : str
Name of tag to be retrieved
tag_value_default : any type, optional; default=None
Default/fallback value if tag is not found
raise_error : bool
whether a ``ValueError`` is raised when the tag is not found
Returns
-------
tag_value : Any
Value of the ``tag_name`` tag in ``self``.
If not found, raises an error if
``raise_error`` is True, otherwise it returns ``tag_value_default``.
Raises
------
ValueError, if ``raise_error`` is ``True``.
The ``ValueError`` is then raised if ``tag_name`` is
not in ``self.get_tags().keys()``.
"""
self._deprecate_tag_warn([tag_name])
return super(TagAliaserMixin, self).get_tag(
tag_name=tag_name,
tag_value_default=tag_value_default,
raise_error=raise_error,
)
def set_tags(self, **tag_dict):
"""Set instance level tag overrides to given values.
Every ``scikit-base`` compatible object has a dictionary of tags,
which are used to store metadata about the object.
Tags are key-value pairs specific to an instance ``self``,
they are static flags that are not changed after construction
of the object. They may be used for metadata inspection,
or for controlling behaviour of the object.
``set_tags`` sets dynamic tag overrides
to the values as specified in ``tag_dict``, with keys being the tag name,
and dict values being the value to set the tag to.
The ``set_tags`` method
should be called only in the ``__init__`` method of an object,
during construction, or directly after construction via ``__init__``.
Current tag values can be inspected by ``get_tags`` or ``get_tag``.
Parameters
----------
**tag_dict : dict
Dictionary of tag name: tag value pairs.
Returns
-------
Self
Reference to self.
"""
self._deprecate_tag_warn(tag_dict.keys())
tag_dict = self._complete_dict(tag_dict)
super(TagAliaserMixin, self).set_tags(**tag_dict)
return self
@classmethod
def _complete_dict(cls, tag_dict):
"""Add all aliased and aliasing tags to the dictionary."""
alias_dict = cls.alias_dict
deprecated_tags = set(tag_dict.keys()).intersection(alias_dict.keys())
new_tags = set(tag_dict.keys()).intersection(alias_dict.values())
if len(deprecated_tags) > 0 or len(new_tags) > 0:
new_tag_dict = deepcopy(tag_dict)
# for all tag strings being set, write the value
# to all tags that could *be aliased by* the string
# and all tags that could be *aliasing* the string
# this way we ensure upwards and downwards compatibility
for old_tag, new_tag in alias_dict.items():
for tag in tag_dict:
if tag == old_tag and new_tag != "":
new_tag_dict[new_tag] = tag_dict[tag]
if tag == new_tag:
new_tag_dict[old_tag] = tag_dict[tag]
return new_tag_dict
else:
return tag_dict
@classmethod
def _deprecate_tag_warn(cls, tags):
"""Print warning message for tag deprecation.
Parameters
----------
tags : list of str
Raises
------
DeprecationWarning for each tag in tags that is aliased by cls.alias_dict
"""
for tag_name in tags:
if tag_name in cls.alias_dict.keys():
version = cls.deprecate_dict[tag_name]
new_tag = cls.alias_dict[tag_name]
msg = f"tag {tag_name!r} will be removed in sktime version {version}"
if new_tag != "":
msg += (
f" and replaced by {new_tag!r}, please use {new_tag!r} instead"
)
else:
msg += ", please remove code that access or sets {tag_name!r}"
warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
class BaseEstimator(BaseObject):
"""Base class for estimators with scikit-learn and sktime design patterns.
Extends BaseObject to include basic functionality for fittable estimators.
"""
# tuple of non-BaseObject classes that count as nested objects
# get_fitted_params will retrieve parameters from these, too
# override in descendant class - common choice: BaseEstimator from sklearn
GET_FITTED_PARAMS_NESTING = ()
def __init__(self):
"""Construct BaseEstimator."""
self._is_fitted = False
super().__init__()
@property
def is_fitted(self):
"""Whether ``fit`` has been called.
Inspects object's ``_is_fitted` attribute that should initialize to ``False``
during object construction, and be set to True in calls to an object's
`fit` method.
Returns
-------
bool
Whether the estimator has been `fit`.
"""
if hasattr(self, "_is_fitted"):
return self._is_fitted
else:
return False
def check_is_fitted(self, method_name=None):
"""Check if the estimator has been fitted.
Check if ``_is_fitted`` attribute is present and ``True``.
The ``is_fitted``
attribute should be set to ``True`` in calls to an object's ``fit`` method.
If not, raises a ``NotFittedError``.
Parameters
----------
method_name : str, optional
Name of the method that called this function. If provided, the error
message will include this information.
Raises
------
NotFittedError
If the estimator has not been fitted yet.
"""
if not self.is_fitted:
if method_name is None:
msg = (
f"This instance of {self.__class__.__name__} has not been fitted "
f"yet. Please call `fit` first."
)
else:
msg = (
f"This instance of {self.__class__.__name__} has not been fitted "
f"yet. Please call `fit` before calling `{method_name}`."
)
raise NotFittedError(msg)
def get_fitted_params(self, deep=True):
"""Get fitted parameters.
State required:
Requires state to be "fitted".
Parameters
----------
deep : bool, default=True
Whether to return fitted parameters of components.
* If True, will return a dict of parameter name : value for this object,
including fitted parameters of fittable components
(= BaseEstimator-valued parameters).
* If False, will return a dict of parameter name : value for this object,
but not include fitted parameters of components.
Returns
-------
fitted_params : dict with str-valued keys
Dictionary of fitted parameters, paramname : paramvalue
keys-value pairs include:
* always: all fitted parameters of this object, as via ``get_param_names``
values are fitted parameter value for that key, of this object
* if ``deep=True``, also contains keys/value pairs of component parameters
parameters of components are indexed as ``[componentname]__[paramname]``
all parameters of ``componentname`` appear as ``paramname`` with its value
* if ``deep=True``, also contains arbitrary levels of component recursion,
e.g., ``[componentname]__[componentcomponentname]__[paramname]``, etc
"""
self.check_is_fitted(method_name="get_fitted_params")
# collect non-nested fitted params of self
fitted_params = self._get_fitted_params()
# the rest is only for nested parameters
# so, if deep=False, we simply return here
if not deep:
return fitted_params
# add all nested parameters from components that are skbase BaseEstimator
c_dict = self._components()
for c, comp in c_dict.items():
if isinstance(comp, BaseEstimator) and comp._is_fitted:
c_f_params = comp.get_fitted_params(deep=deep)
c = c.rstrip("_")
c_f_params = {f"{c}__{k}": v for k, v in c_f_params.items()}
fitted_params.update(c_f_params)
# add all nested parameters from components that are sklearn estimators
# we do this recursively as we have to reach into nested sklearn estimators
any_components_left_to_process = True
old_new_params = fitted_params
# this loop recursively and iteratively processes components inside components
while any_components_left_to_process:
new_params = {}
for c, comp in old_new_params.items():
if isinstance(comp, self.GET_FITTED_PARAMS_NESTING):
c_f_params = self._get_fitted_params_default(comp)
c = c.rstrip("_")
c_f_params = {f"{c}__{k}": v for k, v in c_f_params.items()}
new_params.update(c_f_params)
fitted_params.update(new_params)
old_new_params = new_params.copy()
n_new_params = len(new_params)
any_components_left_to_process = n_new_params > 0
return fitted_params
def _get_fitted_params_default(self, obj=None):
"""Obtain fitted params of object, per sklearn convention.
Extracts a dict with {paramstr : paramvalue} contents,
where paramstr are all string names of "fitted parameters".
A "fitted attribute" of obj is one that ends in "_" but does not start with "_".
"fitted parameters" are names of fitted attributes, minus the "_" at the end.
Parameters
----------
obj : any object, optional, default=self
Returns
-------
fitted_params : dict with str keys
fitted parameters, keyed by names of fitted parameter
"""
obj = obj if obj else self
# default retrieves all self attributes ending in "_"
# and returns them with keys that have the "_" removed
#
# get all attributes ending in "_", exclude any that start with "_" (private)
fitted_params = [
attr for attr in dir(obj) if attr.endswith("_") and not attr.startswith("_")
]
def getattr_safe(obj, attr):
"""Get attribute of object, safely.
Safe version of getattr, that returns None if attribute does not exist,
or if an exception is raised during getattr.
Also returns a boolean indicating whether the attribute was successfully
retrieved, to distinguish between None value and non-existent attribute,
or exception during getattr.
Parameters
----------
obj : any object
object to get attribute from
attr : str
attribute name to get from obj
Returns
-------
attr : Any
attribute of obj, if it exists and does not raise on getattr;
otherwise None
success : bool
whether the attribute was successfully retrieved
"""
try:
if hasattr(obj, attr):
attr = getattr(obj, attr)
return attr, True
else:
return None, False
except Exception:
return None, False
fitted_param_dict = {}
for p in fitted_params:
attr, success = getattr_safe(obj, p)
if success:
p_name = p[:-1] # remove the "_" at the end to get the parameter name
fitted_param_dict[p_name] = attr
return fitted_param_dict
def _get_fitted_params(self):
"""Get fitted parameters.
private _get_fitted_params, called from get_fitted_params
State required:
Requires state to be "fitted".
Returns
-------
fitted_params : dict with str keys
fitted parameters, keyed by names of fitted parameter
"""
return self._get_fitted_params_default()
# Adapted from sklearn's `_clone_parametrized()`
def _clone(estimator, *, safe=True):
"""Construct a new unfitted estimator with the same parameters.
Clone does a deep copy of the model in an estimator
without actually copying attached data. It returns a new estimator
with the same parameters that has not been fitted on any data.
Parameters
----------
estimator : {list, tuple, set} of estimator instance or a single \
estimator instance
The estimator or group of estimators to be cloned.
safe : bool, default=True
If safe is False, clone will fall back to a deep copy on objects
that are not estimators.
Returns
-------
estimator : object
The deep copy of the input, an estimator if input is an estimator.
Notes
-----
If the estimator's `random_state` parameter is an integer (or if the
estimator doesn't have a `random_state` parameter), an *exact clone* is
returned: the clone and the original estimator will give the exact same
results. Otherwise, *statistical clone* is returned: the clone might
return different results from the original estimator. More details can be
found in :ref:`randomness`.
"""
estimator_type = type(estimator)
if estimator_type is dict:
return {k: _clone(v, safe=safe) for k, v in estimator.items()}
if estimator_type in (list, tuple, set, frozenset):
return estimator_type([_clone(e, safe=safe) for e in estimator])
elif not hasattr(estimator, "get_params") or isinstance(estimator, type):
if not safe:
return deepcopy(estimator)
else:
if isinstance(estimator, type):
raise TypeError(
"Cannot clone object. "
+ "You should provide an instance of "
+ "scikit-learn estimator instead of a class."
)
else:
raise TypeError(
"Cannot clone object '%s' (type %s): "
"it does not seem to be a scikit-learn "
"estimator as it does not implement a "
"'get_params' method." % (repr(estimator), type(estimator))
)
klass = estimator.__class__
new_object_params = estimator.get_params(deep=False)
for name, param in new_object_params.items():
new_object_params[name] = _clone(param, safe=False)
new_object = klass(**new_object_params)
params_set = new_object.get_params(deep=False)
# quick sanity check of the parameters of the clone
for name in new_object_params:
param1 = new_object_params[name]
param2 = params_set[name]
if param1 is not param2:
raise RuntimeError(
"Cannot clone object %s, as the constructor "
"either does not set or modifies parameter %s" % (estimator, name)
)
# This is an extension to the original sklearn implementation
if isinstance(estimator, BaseObject) and estimator.get_config()["clone_config"]:
new_object.set_config(**estimator.get_config())
return new_object
def _check_clone(original, clone):
from skbase.utils.deep_equals import deep_equals
self_params = original.get_params(deep=False)
# check that all attributes are written to the clone
for attrname in self_params.keys():
if not hasattr(clone, attrname):
raise RuntimeError(
f"error in {original}.clone, __init__ must write all arguments "
f"to self and not mutate them, but {attrname} was not found. "
f"Please check __init__ of {original}."
)
clone_attrs = {attr: getattr(clone, attr) for attr in self_params.keys()}
# check equality of parameters post-clone and pre-clone
clone_attrs_valid, msg = deep_equals(self_params, clone_attrs, return_msg=True)
if not clone_attrs_valid:
raise RuntimeError(
f"error in {original}.clone, __init__ must write all arguments "
f"to self and not mutate them, but this is not the case. "
f"Error on equality check of arguments (x) vs parameters (y): {msg}"
)