Source code for caliber.binary_classification.iterative_fitting.base

import logging
from typing import Callable, Optional

import numpy as np
from scipy.optimize import minimize
from sklearn.metrics import log_loss

from caliber.binary_classification.base import AbstractBinaryClassificationModel


class IterativeFittingBinaryClassificationModel(AbstractBinaryClassificationModel):
    def __init__(
        self,
        max_rounds: int = 1000,
        split: float = 0.8,
        seed: int = 0,
        loss_fn: Callable[[np.ndarray, np.ndarray], float] = log_loss,
        early_stopping_loss_fn: Callable[[np.ndarray, np.ndarray], float] = log_loss,
        n_bins: int = 10,
    ):
        super().__init__()
        self.max_rounds = max_rounds
        self.split = split
        self.loss_fn = loss_fn
        self.early_stopping_loss_fn = early_stopping_loss_fn
        self._rng = np.random.default_rng(seed)
        self.n_bins = n_bins
        self._bin_edges = None
    def fit(
        self,
        probs: np.ndarray,
        targets: np.ndarray,
        groups: Optional[np.ndarray] = None,
    ) -> dict:
        # Shuffle the data and split it into a calibration and a validation set.
        n_data = len(probs)
        perm = self._rng.choice(n_data, n_data, replace=False)
        calib_size = int(np.ceil(n_data * self.split))
        calib_probs, val_probs = probs[perm[:calib_size]], probs[perm[calib_size:]]
        calib_targets, val_targets = (
            targets[perm[:calib_size]],
            targets[perm[calib_size:]],
        )
        if groups is not None:
            calib_groups, val_groups = (
                groups[perm[:calib_size]],
                groups[perm[calib_size:]],
            )
        else:
            # `groups` is optional; only split it when it is provided.
            calib_groups, val_groups = None, None

        self._bin_edges = self._get_bin_edges()
        calib_features = self._get_features(calib_probs, calib_groups)
        n_features = calib_features.shape[1]

        self._params = []
        val_losses = [self.early_stopping_loss_fn(val_targets, val_probs)]

        def loss_fn(_params: np.ndarray) -> float:
            return self.loss_fn(
                calib_targets,
                self._predict_proba(_params, calib_probs, calib_features),
            )

        for t in range(self.max_rounds):
            # Fit one additive correction on the calibration split.
            params = minimize(loss_fn, np.zeros(n_features)).x

            # Evaluate the correction on the validation split.
            val_features = self._get_features(val_probs, val_groups)
            val_probs = self._predict_proba(params, val_probs, val_features)
            val_losses.append(self.early_stopping_loss_fn(val_targets, val_probs))
            if val_losses[-1] >= val_losses[-2]:
                logging.info(
                    f"Early stopping triggered after {t} rounds. "
                    "The loss started increasing on the validation data."
                )
                break

            # Accept the correction and update the calibration probabilities.
            calib_probs = self._predict_proba(params, calib_probs, calib_features)
            calib_features = self._get_features(calib_probs, calib_groups)
            self._params.append(params)

        return dict(n_iter=len(val_losses) - 1, val_losses=val_losses)
    def predict_proba(
        self, probs: np.ndarray, groups: Optional[np.ndarray] = None
    ) -> np.ndarray:
        if self._params is None:
            raise ValueError("Run `fit` first.")
        probs = np.copy(probs)
        for params in self._params:
            features = self._get_features(probs, groups)
            probs = self._predict_proba(params, probs, features)
        return probs
    def predict(
        self, probs: np.ndarray, groups: Optional[np.ndarray] = None
    ) -> np.ndarray:
        return (self.predict_proba(probs, groups) > 0.5).astype(int)
    @staticmethod
    def _predict_proba(
        params: np.ndarray, probs: np.ndarray, features: np.ndarray
    ) -> np.ndarray:
        return np.clip(probs + np.dot(features, params), 0, 1)

    def _get_features(
        self,
        probs: np.ndarray,
        groups: Optional[np.ndarray] = None,
    ):
        # `np.digitize` maps probabilities to bin indices 1, ..., n_bins + 1
        # (the last index catches probs equal to 1), so each row becomes a
        # one-hot indicator over n_bins + 1 bins.
        bin_indices = np.digitize(probs, self._bin_edges)
        features = np.stack(
            [bin_indices == i for i in range(1, self.n_bins + 2)], axis=1
        )
        if groups is not None:
            # Interact the bin indicators with each group column.
            features = np.concatenate(
                [features * g[:, None] for g in groups.T], axis=1
            )
        return features

    def _get_bin_edges(self):
        return np.linspace(0, 1, self.n_bins + 1)
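
Below is a minimal usage sketch, not part of the module above. It fits the calibrator on synthetic, deliberately over-sharpened probabilities and then calibrates them. The data and variable names are illustrative, and the sketch assumes the class is instantiated directly rather than through a subclass.

# Usage sketch -- illustrative only, not part of the module above.
import numpy as np

from caliber.binary_classification.iterative_fitting.base import (
    IterativeFittingBinaryClassificationModel,
)

rng = np.random.default_rng(42)
n = 2000

# Synthetic ground-truth probabilities and binary targets drawn from them.
true_probs = rng.uniform(size=n)
targets = rng.binomial(1, true_probs)

# Deliberately miscalibrated (over-sharpened) model outputs.
probs = true_probs**3 / (true_probs**3 + (1 - true_probs) ** 3)

model = IterativeFittingBinaryClassificationModel(max_rounds=50, n_bins=10)
status = model.fit(probs, targets)
print(status["n_iter"], status["val_losses"])

calibrated_probs = model.predict_proba(probs)
preds = model.predict(probs)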