from typing import Optional
import numpy as np
from caliber.binary_classification.metrics.precrec_error import precision_recall_error
from caliber.binary_classification.minimizing.linear_scaling.calibration.base import (
CalibrationLinearScalingBinaryClassificationModel,
)
from caliber.binary_classification.minimizing.linear_scaling.mixins.fit.brute_fit import (
LinearScalingBruteFitBinaryClassificationMixin,
)
from caliber.binary_classification.utils.knee_point import knee_point
[docs]
class PrecisionRecallLinearScalingBinaryClassificationModel(
LinearScalingBruteFitBinaryClassificationMixin,
CalibrationLinearScalingBinaryClassificationModel,
):
def __init__(
self,
threshold: Optional[float] = None,
lam: float = 0.01,
minimize_options: Optional[dict] = None,
has_intercept: bool = True,
n_thresholds: int = 100,
):
super().__init__(
loss_fn=self._precision_recall_error,
minimize_options=minimize_options,
has_intercept=has_intercept,
has_bivariate_slope=False,
)
self._lam = lam
if threshold is None:
self._tune_threshold = True
else:
self._tune_threshold = True
self.threshold = threshold
self._n_thresholds = n_thresholds
def _precision_recall_error(self, targets: np.ndarray, probs: np.ndarray) -> float:
self._maybe_update_threshold(probs, targets)
return precision_recall_error(targets, probs, self.threshold)
def _maybe_update_threshold(self, probs: np.ndarray, targets: np.ndarray) -> None:
if self._tune_threshold:
self.threshold = knee_point(probs, targets, self._n_thresholds)[2]