# Source code for caliber.binary_classification.minimizing.linear_scaling.calibration.focal_linear_scaling

from functools import partial
from typing import Optional

from caliber.binary_classification.metrics.focal_loss import focal_loss
from caliber.binary_classification.minimizing.linear_scaling.calibration.base import (
    CalibrationLinearScalingBinaryClassificationModel,
)
from caliber.binary_classification.minimizing.linear_scaling.mixins.fit.smooth_fit import (
    LinearScalingSmoothFitBinaryClassificationMixin,
)


class FocalLinearScalingBinaryClassificationModel(
    LinearScalingSmoothFitBinaryClassificationMixin,
    CalibrationLinearScalingBinaryClassificationModel,
):
    """Calibration linear-scaling binary classification model using a focal loss.

    The class combines ``LinearScalingSmoothFitBinaryClassificationMixin`` with
    ``CalibrationLinearScalingBinaryClassificationModel`` and fixes the loss
    function handed to the parent constructor to ``focal_loss`` with ``gamma``
    pre-bound. Everything else is forwarded to the parent unchanged.
    """

    def __init__(
        self,
        minimize_options: Optional[dict] = None,
        has_intercept: bool = True,
        has_bivariate_slope: bool = False,
        gamma: float = 2.0,
    ) -> None:
        """Configure the model with a focal loss parameterised by ``gamma``.

        Args:
            minimize_options: Forwarded as-is to the parent constructor.
                Presumably options for the underlying minimizer; confirm
                against the base class.
            has_intercept: Forwarded as-is to the parent constructor.
                Presumably toggles an intercept term in the linear scaling;
                confirm against the base class.
            has_bivariate_slope: Forwarded as-is to the parent constructor.
                Presumably toggles a two-component slope; confirm against
                the base class.
            gamma: Value bound to the ``gamma`` keyword of ``focal_loss``.
        """
        # Bind gamma up front so the parent receives a loss callable that no
        # longer needs the gamma keyword supplied at call time.
        super().__init__(
            loss_fn=partial(focal_loss, gamma=gamma),
            minimize_options=minimize_options,
            has_intercept=has_intercept,
            has_bivariate_slope=has_bivariate_slope,
        )