Source code for datasetsdefer.generic_dataset

import logging
import sys

import numpy as np
import torch
from .basedataset import BaseDataset

# Append "../" to the module search path; a relative entry like this is
# resolved against the current working directory at import time.
# NOTE(review): none of the imports visible in this file need it - confirm
# whether other modules rely on this side effect before removing it.
sys.path.append("../")
from PIL import Image
from torch.utils.data import Dataset
# Raise the "PIL" logger threshold to WARNING so that lower-severity records
# emitted under that logger name are dropped.
logging.getLogger("PIL").setLevel(logging.WARNING)


class GenericImageExpertDataset(Dataset):
    """Map-style dataset yielding ``(image, label, expert_prediction)`` triples.

    Images are either held in memory or, when ``to_open`` is True, stored as
    file paths that are opened lazily (and converted to RGB) on access.
    """

    def __init__(self, images, targets, expert_preds, transforms_fn, to_open=False):
        """
        Args:
            images (list): List of images, or of image file paths when
                ``to_open`` is True
            targets (list): List of labels
            expert_preds (list): List of expert predictions
            transforms_fn (function): Function to apply to images; may be
                None, in which case images are used as-is
            to_open (bool): Whether to open images or not (RGB reader)
        """
        self.images = images
        self.targets = np.array(targets)
        self.expert_preds = np.array(expert_preds)
        self.transforms_fn = transforms_fn
        self.to_open = to_open

    def __getitem__(self, index):
        """Return the item stored at ``index``.

        Args:
            index (int): Position of the item in the dataset.

        Returns:
            tuple: ``(image, label, expert_pred)`` where ``image`` is the
            (optionally opened and transformed) image wrapped in a
            ``torch.FloatTensor``, and ``label`` / ``expert_pred`` are the
            matching entries of ``targets`` / ``expert_preds``. The index
            itself is not part of the returned tuple.
        """
        label = self.targets[index]
        image = self.images[index]

        if self.to_open:
            # Opening must not depend on a transform being configured.
            # ``convert`` returns a fully loaded copy, so the underlying file
            # can be closed as soon as the ``with`` block exits.
            with Image.open(image) as opened:
                image = opened.convert("RGB")

        if self.transforms_fn is not None:
            image = self.transforms_fn(image)
        elif self.to_open:
            # Without a transform the PIL image has to be turned into an
            # array before it can be wrapped in a tensor. The result keeps
            # PIL's height x width x channel layout and 0-255 value range.
            image = np.asarray(image, dtype=np.float32)

        expert_pred = self.expert_preds[index]
        return torch.FloatTensor(image), label, expert_pred

    def __len__(self):
        """Return the number of samples, i.e. the number of targets."""
        return len(self.targets)
class GenericDatasetDeferral(BaseDataset):
    """Build train/validation/test dataloaders from in-memory data with human predictions.

    The constructor stores the configuration and immediately calls
    :meth:`generate_data`, which creates ``data_train``, ``data_val`` and
    ``data_test`` (``TensorDataset`` objects) together with the matching
    ``data_train_loader``, ``data_val_loader`` and ``data_test_loader``.
    """

    def __init__(
        self,
        data_train,
        data_test=None,
        test_split=0.2,
        val_split=0.1,
        batch_size=100,
        transforms=None,
    ):
        """
        Args:
            data_train: training data expected as dict with keys 'data_x',
                'data_y', 'hum_preds'
            data_test: test data expected as dict with keys 'data_x',
                'data_y', 'hum_preds'; when None, a test set is carved out of
                the training data instead
            test_split: fraction of training data to use for test (only used
                when ``data_test`` is None)
            val_split: fraction of training data to use for validation
            batch_size: batch size for dataloaders
            transforms: transforms to apply to images (stored on the
                instance; not applied by this class itself)
        """
        self.data_train = data_train
        self.data_test = data_test
        self.test_split = test_split
        self.val_split = val_split
        self.batch_size = batch_size
        self.train_split = 1 - test_split - val_split
        self.transforms = transforms
        self.generate_data()

    def generate_data(self):
        """Split the data and create the datasets and dataloaders.

        The training samples are shuffled with a fixed seed (42) and split
        into train/validation parts, plus a test part when no separate
        ``data_test`` was supplied. The values under 'data_x', 'data_y' and
        'hum_preds' must be tensors, since they are accessed through ``.data``
        and wrapped in ``TensorDataset``.

        Note that ``self.data_train`` and ``self.data_test`` are overwritten:
        they hold the input dicts on entry and ``TensorDataset`` objects on
        exit, so this method is not meant to be called a second time.

        Raises:
            ValueError: if 'data_x', 'data_y' and 'hum_preds' of the training
                data do not all have the same length.
        """
        train_x = self.data_train["data_x"]
        train_y = self.data_train["data_y"]
        train_hum_preds = self.data_train["hum_preds"]

        # Derive the sample count from the data itself rather than relying on
        # an attribute that this class never assigns.
        total_samples = len(train_x)
        if len(train_y) != total_samples or len(train_hum_preds) != total_samples:
            raise ValueError(
                "data_x, data_y and hum_preds of data_train must have the same length"
            )

        val_size = int(self.val_split * total_samples)
        if self.data_test is not None:
            # Train size is the remainder, so the two sizes always add up to
            # the total even when the validation fraction gets truncated.
            split_sizes = [total_samples - val_size, val_size]
        else:
            train_size = int(self.train_split * total_samples)
            split_sizes = [
                train_size,
                val_size,
                total_samples - train_size - val_size,
            ]

        # Split the index range once and reuse the indices for features,
        # labels and human predictions so that the three stay aligned.
        index_splits = torch.utils.data.random_split(
            range(total_samples),
            split_sizes,
            generator=torch.Generator().manual_seed(42),
        )

        def build_subset(split):
            """Gather the rows selected by ``split`` into a TensorDataset."""
            rows = split.indices
            return torch.utils.data.TensorDataset(
                train_x.data[rows],
                train_y.data[rows],
                train_hum_preds.data[rows],
            )

        if self.data_test is not None:
            test_dataset = torch.utils.data.TensorDataset(
                self.data_test["data_x"],
                self.data_test["data_y"],
                self.data_test["hum_preds"],
            )
        else:
            test_dataset = build_subset(index_splits[2])

        self.data_train = build_subset(index_splits[0])
        self.data_val = build_subset(index_splits[1])
        self.data_test = test_dataset

        self.data_train_loader = torch.utils.data.DataLoader(
            self.data_train, batch_size=self.batch_size, shuffle=True
        )
        self.data_val_loader = torch.utils.data.DataLoader(
            self.data_val, batch_size=self.batch_size, shuffle=True
        )
        self.data_test_loader = torch.utils.data.DataLoader(
            self.data_test, batch_size=self.batch_size, shuffle=True
        )