Source code for caliber.multiclass_classification.pred_from_probs_mixin
import abc
import numpy as np
from caliber.multiclass_classification.checks_mixin import (
MulticlassClassificationChecksMixin,
)
class PredFromProbsMulticlassClassificationMixin(MulticlassClassificationChecksMixin):
    """Mixin deriving hard class predictions from a probability estimator.

    Subclasses supply ``predict_proba``. This mixin turns its output into
    class labels by taking the argmax along axis 1.
    """

    def __init__(self, *args, **kwargs):
        # Cooperative initialisation: forward everything to the next class
        # in the MRO unchanged.
        super().__init__(*args, **kwargs)

    @abc.abstractmethod
    def predict_proba(self, probs: np.ndarray) -> np.ndarray:
        """Return calibrated class probabilities for ``probs``.

        Must be implemented by subclasses. ``predict`` reduces the returned
        array along axis 1, so it needs at least two dimensions. Presumably
        it is (n_samples, n_classes) — TODO confirm against implementations.
        """

    def predict(self, probs: np.ndarray) -> np.ndarray:
        """Return the index of the most probable class for each row.

        Args:
            probs: Input passed straight through to ``predict_proba``.

        Returns:
            Argmax of ``predict_proba(probs)`` along axis 1.
        """
        class_probabilities = self.predict_proba(probs)
        return np.argmax(class_probabilities, axis=1)