# Source code for pytorch_eo.datasets.eurosat.EuroSATBase

from pytorch_eo.utils import download_url, unzip_file
import pytorch_lightning as pl
import os
from sklearn.model_selection import train_test_split
from pathlib import Path
from torch.utils.data import DataLoader


class EuroSATBase(pl.LightningDataModule):
    """Base data module for the EuroSAT land-cover dataset.

    Handles (optional) download/extraction of the compressed dataset,
    discovery of the class folders, and stratified train/val/test splits.
    Subclasses presumably build the actual ``train_ds``/``val_ds``/``test_ds``
    datasets consumed by the dataloader methods — TODO confirm.
    """

    def __init__(self, batch_size, download, url, path, compressed_data_filename,
                 data_folder, test_size, val_size, random_state, num_workers,
                 pin_memory, shuffle, verbose):
        super().__init__()
        # where the data lives on disk
        self.path = Path(path)
        self.url = url
        self.compressed_data_filename = compressed_data_filename
        self.data_folder = data_folder
        self.download = download
        # EuroSAT has exactly 10 land-cover classes
        self.num_classes = 10
        # split configuration
        self.test_size = test_size
        self.val_size = val_size
        self.random_state = random_state
        # dataloader configuration
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.shuffle = shuffle
        self.verbose = verbose
[docs] def setup(self, stage=None): # download data compressed_data_path = self.path / self.compressed_data_filename uncompressed_data_path = self.path / self.data_folder if self.download: # create data folder os.makedirs(self.path, exist_ok=True) # check data is not already downloaded if not os.path.isfile(compressed_data_path): print("downloading data ...") download_url(self.url, compressed_data_path) else: print("data already downloaded !") # extract if not os.path.isdir(uncompressed_data_path): unzip_file(compressed_data_path, self.path, msg="extracting data ...") else: print("data already extracted !") # TODO: check data is correct else: assert os.path.isdir(uncompressed_data_path), 'data not found' # TODO: check data is correct # retrieve classes from folder structure self.classes = sorted(os.listdir(uncompressed_data_path)) assert len(self.classes) == self.num_classes # generate list of images and labels images, encoded = [], [] for ix, label in enumerate(self.classes): _images = os.listdir(uncompressed_data_path / label) images += [uncompressed_data_path / label / img for img in _images] encoded += [ix]*len(_images) if self.verbose: print(f'Number of images: {len(images)}') # data splits train_images, self.test_images, train_labels, self.test_labels = train_test_split( images, encoded, stratify=encoded, test_size=self.test_size, random_state=self.random_state ) self.train_images, self.val_images, self.train_labels, self.val_labels = train_test_split( train_images, train_labels, stratify=train_labels, test_size=self.val_size, random_state=self.random_state ) if self.verbose: print("training samples", len(self.train_images)) print("validation samples", len(self.val_images)) print("test samples", len(self.test_images))
[docs] def train_dataloader(self): return DataLoader( self.train_ds, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.pin_memory, shuffle=self.shuffle )
[docs] def val_dataloader(self): return DataLoader( self.val_ds, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.pin_memory, shuffle=False )
[docs] def test_dataloader(self): return DataLoader( self.test_ds, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.pin_memory, shuffle=False )