Source code for sconce.datasets.csv_image_folder

from collections import defaultdict
from sconce import transforms
from torchvision.datasets import folder

import csv
import os
import os.path
import torch.utils.data as data


class CsvImageFolder(data.Dataset):
    """
    A Dataset that reads images from a folder and classes from a csv file.

    Each csv row names one image (by filename without extension) and the
    classes it belongs to.  The image's extension is resolved by looking at
    the files actually present in ``root``.

    Arguments:
        root (string): directory where the images can be found.
        csv_path (string): the path to the csv file containing image
            filenames and classes.
        filename_key (string, optional): the column header of the csv for the
            column that contains image filenames (without extensions).
        classes_key (string, optional): the column header of the csv for the
            column that contains classes for each image.
        csv_delimiter (string, optional): the character(s) used to separate
            fields in the csv file.
        classes_delimiter (string, optional): the character(s) used to
            separate the individual classes inside the ``classes_key`` field.
        loader (callable, optional): a function to load a sample given its
            path.
        extensions (list[string], optional): a list of allowed extensions,
            tried in order.  E.g, ``['.jpg', '.tif']``
        transform (callable, optional): A function/transform that takes in a
            sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that
            takes in the target and transforms it.  When left at its default
            (the ``transforms.NHot`` class itself), an
            ``NHot(size=len(self.classes))`` instance is created once the
            classes are known.

    Attributes:
        class_to_idx (dict): a dictionary mapping class names to indices.
        classes (list[string]): the sorted, human readable names of the
            classes that images can belong to.
        paths (list[string]): for each image, the path to the image on disk.
        targets (list[list[int]]): for each image, a list of class indices to
            which that image belongs.
    """

    def __init__(self, root, csv_path, filename_key='image_name',
                 classes_key='tags', csv_delimiter=',', classes_delimiter=' ',
                 loader=folder.default_loader,
                 extensions=folder.IMG_EXTENSIONS, transform=None,
                 target_transform=transforms.NHot):
        self.root = root
        self.csv_path = csv_path
        self.filename_key = filename_key
        self.classes_key = classes_key
        self.csv_delimiter = csv_delimiter
        self.classes_delimiter = classes_delimiter
        self.loader = loader
        self.extensions = extensions
        self.transform = transform

        # Scan the folder once, eagerly, and keep the result so that the
        # ``found_extensions`` property does not have to list the directory
        # a second time.
        self._found_extensions = self._load_found_extensions()

        self.class_to_idx = {}
        self.classes = []
        self.paths = []
        self.targets = []
        self._load()

        # The NHot *class* acts as a sentinel for "build the default
        # transform"; it can only be instantiated after ``_load`` because it
        # needs the number of classes.
        if target_transform is transforms.NHot:
            self.target_transform = transforms.NHot(size=len(self.classes))
        else:
            self.target_transform = target_transform

    def _load_found_extensions(self):
        """Return ``{base filename: [extensions]}`` for the files in root."""
        found_extensions = defaultdict(list)
        for filename in os.listdir(self.root):
            base, ext = os.path.splitext(filename)
            found_extensions[base].append(ext)
        return dict(found_extensions)

    @property
    def found_extensions(self):
        """dict: base filename -> extensions present in ``root`` (cached)."""
        if self._found_extensions is None:
            self._found_extensions = self._load_found_extensions()
        return self._found_extensions

    def _load(self):
        """Read the csv and populate paths, classes, class_to_idx, targets."""
        paths = []
        classes_per_image = []
        classes_set = set()

        # newline='' is what the csv module asks for so that newlines
        # embedded in quoted fields are interpreted correctly.
        with open(self.csv_path, newline='') as csv_file:
            reader = csv.DictReader(csv_file, delimiter=self.csv_delimiter)
            for row in reader:
                filename = self._get_filename(row)
                paths.append(self._get_path(filename))

                classes = self._get_classes(row)
                classes_per_image.append(classes)
                classes_set.update(classes)

        # Assign everything at the end (instead of appending to self.paths)
        # so that calling _load again does not duplicate entries.
        self.paths = paths
        self.classes = sorted(classes_set)
        self.class_to_idx = {_class: i
                             for i, _class in enumerate(self.classes)}
        self.targets = [[self.class_to_idx[_class] for _class in classes]
                        for classes in classes_per_image]

    def _get_filename(self, row):
        """Return the image's base filename from a csv row."""
        return row[self.filename_key]

    def _get_path(self, base_filename):
        """
        Return the path of the image with the given base filename.

        The first entry of ``self.extensions`` for which a matching file
        exists in ``root`` wins.

        Raises:
            RuntimeError: if no file with that base filename and an allowed
                extension was found in ``root``.
        """
        # Only the extensions recorded for *this* base filename may match; a
        # missing base filename yields no candidates (previously the whole
        # mapping was searched, so its keys could be mistaken for
        # extensions).
        candidates = self.found_extensions.get(base_filename, ())
        for extension in self.extensions:
            if extension in candidates:
                return os.path.join(self.root,
                                    '%s%s' % (base_filename, extension))

        raise RuntimeError(f"No image file with base filename ({base_filename}) "
                           f"found in folder ({self.root}), valid extensions are: {self.extensions}")

    def _get_classes(self, row):
        """Return the list of class names stored in a csv row."""
        return row[self.classes_key].split(self.classes_delimiter)

    def get_sample(self, index):
        """Load and return the (untransformed) image at ``index``."""
        path = self.paths[index]
        return self.loader(path)

    def get_target(self, index):
        """Return the (untransformed) list of class indices at ``index``."""
        return self.targets[index]

    def _get_target(self, row):
        """Return the list of class indices for a csv row."""
        classes = self._get_classes(row)
        return [self.class_to_idx[_class] for _class in classes]

    @property
    def num_classes(self):
        """int: the number of distinct classes found in the csv file."""
        return len(self.classes)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where sample is the image, and target is
            an array of indices of the target class (each passed through its
            transform when one is set).
        """
        sample = self.get_sample(index)
        target = self.get_target(index)

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        return len(self.paths)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of images: {}\n'.format(self.__len__())
        fmt_str += ' Root Location: {}\n'.format(self.root)
        fmt_str += ' Number of different classes: {}\n'.format(len(self.classes))
        # Continuation lines of a multi-line transform repr are indented to
        # line up under the first line, after the label.
        tmp = ' Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(
            tmp, repr(self.transform).replace('\n', '\n' + ' ' * len(tmp)))
        tmp = ' Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(
            tmp,
            repr(self.target_transform).replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str