"""
This module houses the ``DetectronEnsemble`` class, which manages the ensemble of Constrained Disagreement Classifiers (CDCs).
Starting from a base model supplied by a ``BaseModelManager``, the ensemble coordinates the training and evaluation of multiple
CDCs that are designed to systematically disagree with the base model's predictions under controlled conditions.
"""
import numpy as np
import copy
from tqdm import tqdm
from MED3pa.datasets import DatasetsManager
from MED3pa.models.base import BaseModelManager
from .record import DetectronRecordsManager
from .stopper import EarlyStopper
class DetectronEnsemble:
"""
Manages the constrained disagreement classifiers (CDCs) ensemble, designed to disagree with the base model
under specific conditions. This class facilitates the training and evaluation of multiple CDCs, with a focus
on generating models that systematically challenge the predictions of a primary base model.
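
    Example (illustrative sketch; assumes ``base_manager`` is a fully configured
    ``BaseModelManager`` built elsewhere)::

        ensemble = DetectronEnsemble(base_model_manager=base_manager, ens_size=10)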
"""
    def __init__(self, base_model_manager: BaseModelManager, ens_size: int):
"""
Initializes the Detectron Ensemble with a specified base model manager and ensemble size.
Args:
base_model_manager (BaseModelManager): The manager for handling the base model operations, responsible
for training, prediction, and general management of the base model.
ens_size (int): The number of CDCs in the ensemble. This does not include the base model itself.
Attributes:
base_model_manager (BaseModelManager): Instance of BaseModelManager that manages the operations of the base model.
base_model (Model): The actual base model instance retrieved from the model manager.
ens_size (int): Number of CDC models in the ensemble.
cdcs (list of Model): List containing clones of the base model, which are used as CDCs in the ensemble.
"""
self.base_model_manager = base_model_manager
self.base_model = base_model_manager.get_instance()
self.ens_size = ens_size
self.cdcs = [base_model_manager.clone_base_model() for _ in range(ens_size)]
    def evaluate_ensemble(self,
                          datasets: DatasetsManager,
                          n_runs: int,
                          samples_size: int,
                          training_params: dict,
                          set: str = 'reference',
                          patience: int = 3,
                          allow_margin: bool = False,
                          margin: float = None,
                          sampling: str = "uniform"):
"""
Trains the CDCs ensemble to disagree with the base model on a subset of data present in datasets. This process
is repeated for a specified number of runs, each using a different sample of the data.
Args:
datasets (DatasetsManager): Holds the datasets used for training and validation of the base model, as well as the
reference and testing sets for the Detectron.
n_runs (int): Number of runs to train the ensemble. Each run uses a new random sample of data points.
            samples_size (int): Number of data points to sample for each run.
training_params (dict): Additional parameters to use for training the ensemble models.
set (str, optional): Specifies the dataset used for training the ensemble. Options are 'reference' or 'testing'.
Default is 'reference'.
patience (int, optional): The number of consecutive updates without improvement to wait before early stopping.
Default is 3.
allow_margin (bool, optional): Whether to use a probability margin to refine the disagreement. Default is False.
margin (float, optional): The margin threshold above which disagreements in probabilities between the base model
and ensemble are considered significant, if allow_margin is True.
            sampling (str, optional): Specifies the sampling method, either 'uniform' or 'random'. Default is 'uniform'.
Returns:
DetectronRecordsManager: The records manager containing all the evaluation records from the ensemble runs.
Raises:
ValueError: If the specified set is neither 'reference' nor 'testing'.
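
        Example (illustrative; assumes ``ensemble`` is an initialized ``DetectronEnsemble``
        and ``datasets`` already holds the training, validation, reference and testing splits)::

            records = ensemble.evaluate_ensemble(datasets=datasets, n_runs=100,
                                                 samples_size=50, training_params=None,
                                                 set='testing', allow_margin=True, margin=0.05)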
"""
# set up the training, validation and testing sets
training_data = datasets.get_dataset_by_type(dataset_type="training", return_instance=True)
validation_data = datasets.get_dataset_by_type(dataset_type="validation", return_instance=True)
        if set == 'reference':
testing_data = datasets.get_dataset_by_type(dataset_type="reference", return_instance=True)
elif set == 'testing':
testing_data = datasets.get_dataset_by_type(dataset_type="testing", return_instance=True)
else:
raise ValueError("The set used to evaluate the ensemble must be either the reference set or the testing set")
# set up the records manager
record = DetectronRecordsManager(sample_size=samples_size)
model_evaluation = self.base_model.evaluate(testing_data.get_observations(), testing_data.get_true_labels(), ['Auc', 'Accuracy'])
record.set_evaluation(model_evaluation)
        # evaluate the ensemble for n_runs runs
for seed in tqdm(range(n_runs), desc='running seeds'):
            # sample the testing set according to the provided samples_size and current seed
if sampling == "uniform":
testing_set = testing_data.sample_uniform(samples_size, seed)
            elif sampling == "random":
testing_set = testing_data.sample_random(samples_size, seed)
else:
raise ValueError("Available sampling methods are: 'uniform' or 'random'.")
# predict probabilities using the base model on the testing set
base_model_pred_probs = self.base_model.predict(testing_set.get_observations(), True)
            # store the base model's probabilities as pseudo-probabilities and threshold them at 0.5 to obtain pseudo-labels
            testing_set.set_pseudo_probs_labels(base_model_pred_probs, 0.5)
cloned_testing_set = testing_set.clone()
# the base model is always the model with id = 0
model_id = 0
# seed the record
record.seed(seed)
# update the record with the results of the base model
record.update(val_data_x=validation_data.get_observations(), val_data_y=validation_data.get_true_labels(),
sample_size=samples_size, model=self.base_model, model_id=model_id,
predicted_probabilities=testing_set.get_pseudo_probabilities(),
test_data_x=testing_set.get_observations(), test_data_y=testing_set.get_true_labels())
            # set up the early stopper, seeded with the initial sample size; training stops once the number
            # of points the cdcs still agree on stops decreasing for `patience` consecutive models
            stopper = EarlyStopper(patience=patience, mode='min')
            stopper.update(samples_size)
# Initialize the updated count
updated_count = samples_size
# Train the cdcs
for i in range(1, self.ens_size + 1):
# get the current cdc
cdc = self.cdcs[i-1]
# save the model id
model_id = i
                # update the training params with the current seed, which doubles as the model id
                cdc_training_params = copy.deepcopy(training_params)
                if cdc_training_params is not None:
                    cdc_training_params.update({'seed': i})
                else:
                    cdc_training_params = {'seed': i}
# train this cdc to disagree
cdc.train_to_disagree(x_train=training_data.get_observations(), y_train=training_data.get_true_labels(),
x_validation=validation_data.get_observations(), y_validation=validation_data.get_true_labels(),
x_test=testing_set.get_observations(), y_test=testing_set.get_pseudo_labels(),
training_parameters=cdc_training_params,
balance_train_classes=True,
N=updated_count)
# predict probabilities using this cdc
cdc_probabilities = cdc.predict(testing_set.get_observations(), True)
cdc_probabilities_original_set = cdc.predict(cloned_testing_set.get_observations(), True)
                # derive the label predictions of this cdc
                cdc_predictions = cdc_probabilities >= 0.5
                # build the agreement mask used to refine the testing set: True where the cdc agrees with the base model's pseudo-labels
                mask = (cdc_predictions == testing_set.get_pseudo_labels())
                # if a margin is allowed and there are disagreements, check whether the probabilities differ significantly
                if allow_margin and not np.all(mask):
                    # convert to a disagreement mask
                    disagree_mask = ~mask
                    # absolute difference between the cdc's and the base model's probabilities
                    prob_diff = np.abs(testing_set.get_pseudo_probabilities() - cdc_probabilities)
                    # among the disagreed points, flag those whose probability difference falls below
                    # the margin; these are not counted as genuine disagreements
                    refine_mask = (prob_diff < margin) & disagree_mask
                    # flip the flagged points back to agreements
                    mask[refine_mask] = True
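                    # illustrative numbers (hypothetical): with margin=0.2, a point where the base
                    # model outputs 0.55 (pseudo-label 1) and the cdc outputs 0.45 (prediction 0)
                    # disagrees by only |0.55 - 0.45| = 0.1 < margin, so it is flipped back to an
                    # agreement; a cdc output of 0.15 (prob_diff = 0.4 >= margin) remains a
                    # genuine disagreement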
# refine the testing set using the mask
updated_count = testing_set.refine(mask)
# log the results for this model
record.update(val_data_x=validation_data.get_observations(), val_data_y=validation_data.get_true_labels(),
sample_size=updated_count, predicted_probabilities=cdc_probabilities_original_set,
model=cdc, model_id=model_id)
# break if no more data
if updated_count == 0:
break
if stopper.update(updated_count):
# print(f'Early stopping: Converged after {i} models')
break
record.sampling_counts = testing_data.get_sample_counts()
record.freeze()
return record