Source code for caliber.binary_classification.metrics.focal_loss

import numpy as np


[docs] def focal_loss(targets: np.ndarray, probs: np.ndarray, gamma: float = 2.0) -> float: new_probs = np.clip(probs, 1e-6, 1 - 1e-6) new_probs = np.where(targets == 1, new_probs, 1 - new_probs) return -np.sum((1 - new_probs) ** gamma * np.log(new_probs))