# Source code for sconce.datasets.simple_dots

from PIL import Image, ImageDraw
from torchvision import transforms

import numpy as np
import torch
import torch.utils.data as data


def generate_circle_image(x, y, image_size=(40, 40), circle_radius=4):
    """
    Draw a filled white circle on a new single-channel (``'L'``) image.

    Arguments:
        x (float): horizontal position of the circle as a fraction in [0, 1];
            0.0 puts the circle against the left edge, 1.0 against the right.
        y (float): vertical position as a fraction in [0, 1]; 0.0 puts the
            circle against the top edge, 1.0 against the bottom.
        image_size (tuple of int): ``(width, height)`` of the image in pixels.
        circle_radius (int or float): radius of the circle in pixels.

    Returns:
        PIL.Image.Image: the rendered image (background left at the
        ``Image.new`` default fill).

    Raises:
        AssertionError: if ``x`` or ``y`` lies outside [0, 1]. These checks are
            ``assert`` statements, so they are skipped under ``python -O``.
    """
    for value, message in ((x, 'x must be between 0.0 and 1.0'),
                           (y, 'y must be between 0.0 and 1.0')):
        assert 0 <= value <= 1, message

    def to_pixel(fraction, extent):
        # Map [0, 1] onto circle centres so that the bounding box starts at
        # pixel 0 when fraction == 0 and ends at pixel extent - 1 when
        # fraction == 1 (hence the extra "- 1").
        # NOTE(review): if 2 * circle_radius + 1 > extent the span is negative
        # and positions invert; no check is made for that here.
        return (extent - 2 * circle_radius - 1) * fraction + circle_radius

    center_x = to_pixel(x, image_size[0])
    center_y = to_pixel(y, image_size[1])

    canvas = Image.new('L', image_size)
    ImageDraw.Draw(canvas).ellipse(
        (center_x - circle_radius, center_y - circle_radius,
         center_x + circle_radius, center_y + circle_radius),
        fill='white',
        outline='white',
    )
    return canvas


class SimpleDots(data.Dataset):
    """
    A Dataset where each image is a white dot on a black background, the
    targets are the coordinates of the dot.

    The dots are placed on a regular grid with ``int(sqrt(num_images))``
    evenly spaced x-positions and ``num_images // int(sqrt(num_images))``
    evenly spaced y-positions in [0, 1]. The dataset therefore holds exactly
    ``num_images`` samples only when ``num_images`` is divisible by
    ``int(sqrt(num_images))`` (always true for perfect squares); otherwise it
    holds fewer. ``len(dataset)`` always reports the actual count.

    Samples are ordered x-major: sample ``i`` has
    ``x = xs[i // num_ys]`` and ``y = ys[i % num_ys]``.

    Arguments:
        circle_radius (float): the size of the dot (in pixels).
        image_size (tuple of int): the size of the image (in pixels), given
            as ``(width, height)`` as PIL expects.
        num_images (int): the number of images to generate for this dataset.
            Must be at least 1. Best if the sqrt of this number is an integer.

    Raises:
        ValueError: if ``num_images`` is less than 1.

    New in 0.11.0
    """

    def __init__(self, circle_radius=4, image_size=(30, 30), num_images=400):
        # Fail fast with a clear message. With num_images < 1 the grid would
        # have zero columns, and _generate would die with an opaque
        # ZeroDivisionError (0) or a NaN-to-int ValueError (negative).
        if num_images < 1:
            raise ValueError(
                f'num_images must be at least 1, got {num_images!r}')
        self.circle_radius = circle_radius
        self.image_size = image_size
        self.num_images = num_images
        self.samples, self.targets = self._generate()

    def _generate(self):
        """
        Render one image per grid point.

        Returns:
            tuple: ``(samples, targets)``. ``samples`` stacks the
            ``ToTensor`` output of every image (presumably shape
            ``(N, 1, height, width)`` for the grayscale images produced by
            ``generate_circle_image`` -- TODO confirm). ``targets`` has shape
            ``(N, 2)`` and holds the ``(x, y)`` fractions used to draw each
            dot.
        """
        num_xs = int(np.sqrt(self.num_images))
        num_ys = self.num_images // num_xs
        xs = np.linspace(0.0, 1.0, num_xs)
        ys = np.linspace(0.0, 1.0, num_ys)

        transform = transforms.ToTensor()
        images = []
        coords = []
        # x is the outer loop, which fixes the x-major sample ordering.
        for x in xs:
            for y in ys:
                pil_image = generate_circle_image(
                    x, y,
                    image_size=self.image_size,
                    circle_radius=self.circle_radius)
                images.append(transform(pil_image))
                coords.append(torch.Tensor((x, y)))

        return torch.stack(images), torch.stack(coords)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where sample is the image, and target is
            the x, y coordinates of the dot in the image.
        """
        return self.samples[index], self.targets[index]

    def __len__(self):
        """Return the number of generated samples (may be < ``num_images``)."""
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'SimpleDots Dataset\n'
        fmt_str += f' Size of images: {self.image_size}\n'
        fmt_str += f' Radius of circle: {self.circle_radius}\n'
        fmt_str += f' Number of images: {len(self)}\n'
        return fmt_str