Source code for baselines.compare_confidence

import logging
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from .basemethod import BaseMethod

sys.path.append("..")
from helpers.metrics import *
from helpers.utils import *

eps_cst = 1e-8


class CompareConfidence(BaseMethod):
    """Confidence-comparison baseline.

    Trains a classifier independently with cross-entropy, and an expert model
    that predicts whether the human's prediction equals the ground truth.
    At each test point, the query is deferred to the human if the expert
    model's confidence that the human is correct exceeds the classifier's
    top-class confidence.
    """

    def __init__(self, model_class, model_expert, device, plotting_interval=100):
        """
        Args:
            model_class (pytorch model): classifier model
            model_expert (pytorch model): model predicting whether the human prediction is correct
            device (str): device
            plotting_interval (int, optional): batches between logged progress lines. Defaults to 100.
        """
        self.model_class = model_class
        self.model_expert = model_expert
        self.device = device
        self.plotting_interval = plotting_interval
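As an illustration, the class might be instantiated as in the sketch below. The two `nn.Sequential` backbones and the `n_features`/`n_classes` sizes are hypothetical stand-ins, not part of this module; the only structural requirement is that the expert model outputs two logits (human wrong vs. human correct), matching the targets built in `fit_epoch_expert` below.

    import torch
    import torch.nn as nn

    # Hypothetical backbones for illustration only; any nn.Module pair works.
    n_features, n_classes = 64, 10
    model_class = nn.Sequential(nn.Linear(n_features, 128), nn.ReLU(), nn.Linear(128, n_classes))
    model_expert = nn.Sequential(nn.Linear(n_features, 128), nn.ReLU(), nn.Linear(128, 2))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    method = CompareConfidence(model_class.to(device), model_expert.to(device), device)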
    def fit_epoch_class(self, dataloader, optimizer, verbose=True, epoch=1):
        """Train the classifier for a single epoch.

        Args:
            dataloader (dataloader): train dataloader yielding (x, y, human_prediction) triples
            optimizer (optimizer): optimizer for the classifier parameters
            verbose (bool, optional): whether to log loss and accuracy. Defaults to True.
            epoch (int, optional): epoch index, used only for logging. Defaults to 1.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        end = time.time()

        self.model_class.train()
        for batch, (data_x, data_y, hum_preds) in enumerate(dataloader):
            data_x = data_x.to(self.device)
            data_y = data_y.to(self.device)

            outputs = self.model_class(data_x)
            # cross entropy loss
            loss = F.cross_entropy(outputs, data_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            prec1 = accuracy(outputs.data, data_y, topk=(1,))[0]
            losses.update(loss.data.item(), data_x.size(0))
            top1.update(prec1.item(), data_x.size(0))
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if torch.isnan(loss):
                print("Nan loss")
                logging.warning("NaN loss")
                break
            if verbose and batch % self.plotting_interval == 0:
                logging.info(
                    "Epoch: [{0}][{1}/{2}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                    "Prec@1 {top1.val:.3f} ({top1.avg:.3f})".format(
                        epoch,
                        batch,
                        len(dataloader),
                        batch_time=batch_time,
                        loss=losses,
                        top1=top1,
                    )
                )
    def fit_epoch_expert(self, dataloader, optimizer, verbose=True, epoch=1):
        """Train the expert-correctness model for a single epoch.

        Args:
            dataloader (dataloader): train dataloader yielding (x, y, human_prediction) triples
            optimizer (optimizer): optimizer for the expert-model parameters
            verbose (bool, optional): whether to log loss and accuracy. Defaults to True.
            epoch (int, optional): epoch index, used only for logging. Defaults to 1.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        end = time.time()
        loss_fn = nn.CrossEntropyLoss()

        self.model_expert.train()
        for batch, (data_x, data_y, hum_preds) in enumerate(dataloader):
            data_x = data_x.to(self.device)
            data_y = data_y.to(self.device)
            hum_preds = hum_preds.to(self.device)
            # binary target: 1 if the human prediction matches the label
            hum_equal_to_y = (hum_preds == data_y).long()

            outputs = self.model_expert(data_x)
            # cross entropy loss on human correctness
            loss = loss_fn(outputs, hum_equal_to_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            prec1 = accuracy(outputs.data, hum_equal_to_y, topk=(1,))[0]
            losses.update(loss.data.item(), data_x.size(0))
            top1.update(prec1.item(), data_x.size(0))
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if torch.isnan(loss):
                print("Nan loss")
                logging.warning("NaN loss")
                break
            if verbose and batch % self.plotting_interval == 0:
                logging.info(
                    "Epoch: [{0}][{1}/{2}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                    "Prec@1 {top1.val:.3f} ({top1.avg:.3f})".format(
                        epoch,
                        batch,
                        len(dataloader),
                        batch_time=batch_time,
                        loss=losses,
                        top1=top1,
                    )
                )
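Both epoch loops consume dataloaders that yield `(data_x, data_y, hum_preds)` triples. A minimal sketch of such a dataset, with a hypothetical `HumanAnnotatedDataset` wrapper and random tensors standing in for real features, labels, and recorded human predictions:

    import torch
    import torch.utils.data as data

    class HumanAnnotatedDataset(data.Dataset):
        """Illustrative wrapper; the name and tensors here are hypothetical."""

        def __init__(self, features, labels, human_preds):
            assert len(features) == len(labels) == len(human_preds)
            self.features = features
            self.labels = labels
            self.human_preds = human_preds

        def __len__(self):
            return len(self.features)

        def __getitem__(self, idx):
            # Matches the (data_x, data_y, hum_preds) unpacking in the epoch loops.
            return self.features[idx], self.labels[idx], self.human_preds[idx]

    # toy example: 100 points, 64 features, 10 classes
    features = torch.randn(100, 64)
    labels = torch.randint(0, 10, (100,))
    human_preds = torch.randint(0, 10, (100,))  # the human's recorded guesses
    loader = data.DataLoader(HumanAnnotatedDataset(features, labels, human_preds), batch_size=32)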
    def fit(
        self,
        dataloader_train,
        dataloader_val,
        dataloader_test,
        epochs,
        optimizer,
        lr,
        scheduler=None,
        verbose=True,
        test_interval=5,
    ):
        """Fit the classifier and the expert model.

        Args:
            dataloader_train: train dataloader
            dataloader_val: val dataloader
            dataloader_test: test dataloader
            epochs (int): number of training epochs
            optimizer: optimizer constructor, called as optimizer(params, lr=lr)
            lr (float): learning rate
            scheduler (optional): scheduler constructor, called as scheduler(optimizer, num_steps). Defaults to None.
            verbose (bool, optional): whether to log validation metrics. Defaults to True.
            test_interval (int, optional): epochs between validation evaluations. Defaults to 5.

        Returns:
            dict: metrics on the test set
        """
        optimizer_class = optimizer(self.model_class.parameters(), lr=lr)
        optimizer_expert = optimizer(self.model_expert.parameters(), lr=lr)
        if scheduler is not None:
            scheduler_class = scheduler(optimizer_class, len(dataloader_train) * epochs)
            scheduler_expert = scheduler(
                optimizer_expert, len(dataloader_train) * epochs
            )
        for epoch in tqdm(range(epochs)):
            self.fit_epoch_class(
                dataloader_train, optimizer_class, verbose=verbose, epoch=epoch
            )
            self.fit_epoch_expert(
                dataloader_train, optimizer_expert, verbose=verbose, epoch=epoch
            )
            if verbose and epoch % test_interval == 0:
                logging.info(compute_deferral_metrics(self.test(dataloader_val)))
            if scheduler is not None:
                scheduler_class.step()
                scheduler_expert.step()
        return compute_deferral_metrics(self.test(dataloader_test))
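A sketch of a full training call, reusing `method` and `loader` from the sketches above (the same loader stands in for the train, val, and test splits purely for illustration). `torch.optim.Adam` and `CosineAnnealingLR` are example choices that happen to match the `optimizer(params, lr=lr)` and `scheduler(optimizer, num_steps)` call signatures used inside `fit`:

    import torch

    metrics = method.fit(
        dataloader_train=loader,
        dataloader_val=loader,
        dataloader_test=loader,
        epochs=10,
        optimizer=torch.optim.Adam,  # constructor; fit calls optimizer(params, lr=lr)
        lr=1e-3,
        scheduler=torch.optim.lr_scheduler.CosineAnnealingLR,  # fit calls scheduler(opt, n_steps)
        verbose=False,
    )
    print(metrics)  # deferral metrics dict from compute_deferral_metrics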
    def test(self, dataloader):
        """Evaluate the deferral rule on a dataloader and collect per-example results."""
        defers_all = []
        truths_all = []
        hum_preds_all = []
        predictions_all = []  # classifier predictions only
        rej_score_all = []  # rejector score: expert confidence minus classifier confidence
        class_probs_all = []  # classifier probabilities
        self.model_expert.eval()
        self.model_class.eval()
        with torch.no_grad():
            for batch, (data_x, data_y, hum_preds) in enumerate(dataloader):
                data_x = data_x.to(self.device)
                data_y = data_y.to(self.device)
                hum_preds = hum_preds.to(self.device)

                outputs_class = F.softmax(self.model_class(data_x), dim=1)
                outputs_expert = F.softmax(self.model_expert(data_x), dim=1)
                max_class_probs, predicted_class = torch.max(outputs_class.data, 1)

                class_probs_all.extend(outputs_class.cpu().numpy())
                predictions_all.extend(predicted_class.cpu().numpy())
                truths_all.extend(data_y.cpu().numpy())
                hum_preds_all.extend(hum_preds.cpu().numpy())

                # defer when the expert model's probability that the human is
                # correct exceeds the classifier's top-class probability
                defers = []
                for i in range(len(data_y)):
                    rej_score_all.append(
                        outputs_expert[i, 1].item() - max_class_probs[i].item()
                    )
                    if outputs_expert[i, 1] > max_class_probs[i]:
                        defers.append(1)
                    else:
                        defers.append(0)
                defers_all.extend(defers)

        # convert to numpy
        defers_all = np.array(defers_all)
        truths_all = np.array(truths_all)
        hum_preds_all = np.array(hum_preds_all)
        predictions_all = np.array(predictions_all)
        rej_score_all = np.array(rej_score_all)
        class_probs_all = np.array(class_probs_all)
        data = {
            "defers": defers_all,
            "labels": truths_all,
            "hum_preds": hum_preds_all,
            "preds": predictions_all,
            "rej_score": rej_score_all,
            "class_probs": class_probs_all,
        }
        return data
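The dictionary returned by `test` can be consumed directly, again assuming the `method` and `loader` objects from the earlier sketches; the system-accuracy and coverage computations below are illustrative, not taken from `helpers.metrics`:

    import numpy as np

    results = method.test(loader)

    # System prediction: human prediction where deferred, classifier prediction otherwise.
    final_preds = np.where(results["defers"] == 1, results["hum_preds"], results["preds"])
    system_acc = (final_preds == results["labels"]).mean()
    coverage = 1.0 - results["defers"].mean()  # fraction of points the classifier keeps
    print(f"system accuracy: {system_acc:.3f}, coverage: {coverage:.3f}")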