Source code for sconce.data_generators.image_data_generator

from .base import DataGenerator
from torch.utils import data
from torchvision import datasets, transforms

import numpy as np
import os
import pandas as pd
import seaborn as sns
import tempfile


[docs]def get_image_info(image): return {'height': image.height, 'width': image.width, 'num_channels': len(image.getbands())}
[docs]class ImageDataGenerator(DataGenerator): """ A DataGenerator class with some handy methods for image type data. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.summary_df = None @property def num_channels(self): """ The number of image channels, based on looking at the first image in the dataset. """ dataset = self.real_dataset for image, target in dataset: return image.shape[0]
[docs] def get_summary_df(self): """ Return a pandas dataframe that summarizes the image metadata in the dataset. """ if self.summary_df is None: self.summary_df = self._get_summary_df() return self.summary_df
def _get_summary_df(self): info_list = [] dataset = self.real_dataset old_transform = self.real_dataset.transform old_target_tranform = self.real_dataset.target_transform try: dataset.transform = None dataset.target_transform = None for image, label in dataset: info = get_image_info(image) info['label'] = label info_list.append(info) except Exception: pass dataset.transform = old_transform dataset.target_transform = old_target_tranform return pd.DataFrame(info_list)
[docs] def plot_label_summary(self): """ Generate a barchart showing how many images of each label there are. """ summary_df = self.get_summary_df() return sns.countplot(x='label', data=summary_df)
[docs] def plot_size_summary(self): """ Generate a scatter plot showing the sizes of the images in the dataset. """ summary_df = self.get_summary_df() return sns.jointplot(x="height", y="width", kind='scatter', stat_func=None, data=summary_df)
[docs] @classmethod def from_torchvision(cls, batch_size=500, data_location=None, dataset_class=datasets.MNIST, fraction=1.0, num_workers=1, pin_memory=True, shuffle=True, train=True, transform=transforms.ToTensor()): """ Create a DataGenerator from a torchvision dataset class. Arguments: batch_size (int): how large the yielded `inputs` and `targets` should be. See :py:class:`DataLoader` for details. data_location (path): where downloaded dataset should be stored. If ``None`` a system dependent temporary location will be used. dataset_class (class): a torchvision dataset class that supports constructor arguments {'root', 'train', 'download', 'transform'}. For example, MNIST, FashionMnist, CIFAR10, or CIFAR100. fraction (float): (0.0 - 1.0] how much of the original dataset's data to use. num_workers (int): how many subprocesses to use for data loading. See :py:class:`DataLoader` for details. pin_memory (bool): if ``True``, the data loader will copy tensors into CUDA pinned memory before returning them. See :py:class:`DataLoader` for details. shuffle (bool): set to ``True`` to have the data reshuffled at every epoch. See :py:class:`DataLoader` for details. train (bool): if ``True``, creates dataset from training set, otherwise creates from test set. transform (callable): a function/transform that takes in an PIL image and returns a transformed version. """ assert(fraction > 0.0) assert(fraction <= 1.0) if data_location is None: data_location = os.path.join(tempfile.gettempdir(), dataset_class.__name__) dataset = dataset_class(data_location, train=train, download=True, transform=transform) indices = [int(x) for x in np.linspace( start=0, stop=len(dataset) - 1, num=int(len(dataset) * fraction))] subset = data.dataset.Subset(dataset, indices=indices) return cls.from_dataset(subset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=shuffle)
[docs] @classmethod def from_image_folder(self, root, loader_kwargs=None, **dataset_kwargs): """ Create a DataGenerator from a folder of images. See :py:class:`torchvision.datasets.ImageFolder`. Arguments: root (path): the root directory path. loader_kwargs (dict): keyword args provided to the DataLoader constructor. **dataset_kwargs: keyword args provided to the :py:class:`torchvision.datasets.ImageFolder` constructor. """ if loader_kwargs is None: loader_kwargs = {} dataset = datasets.ImageFolder(root=root, **dataset_kwargs) return self.from_dataset(dataset, **loader_kwargs)