
"""Dataset utility functions and tools.

This module contains utility functions and tools used to instantiate data loaders and parsers.
"""

import inspect
import json
import logging
import os
import sys

import h5py
import numpy as np
import tqdm

import thelper.tasks
import thelper.transforms
import thelper.utils

logger = logging.getLogger(__name__)


def create_loaders(config, save_dir=None):
    """Prepares the task and data loaders for a model trainer based on a provided data configuration.

    This function will parse a configuration dictionary and extract all the information required to
    instantiate the requested dataset parsers. Then, combining the task metadata of all these parsers, it
    will evenly split the available samples into three sets (training, validation, test) to be handled by
    different data loaders. These will finally be returned along with the (global) task object.

    The configuration dictionary is expected to contain two fields: ``loaders``, which specifies all
    parameters required for establishing the dataset split, shuffling seeds, and batch size (these are
    listed and detailed below); and ``datasets``, which lists the dataset parser interfaces to instantiate
    as well as their parameters. For more information on the ``datasets`` field, refer to
    :func:`thelper.data.utils.create_parsers`.

    The parameters expected in the ``loaders`` configuration field are the following:

    - ``<train_/valid_/test_>batch_size`` (mandatory): specifies the (mini)batch size to use in the data
      loaders. If you get an 'out of memory' error at runtime, try reducing it.
    - ``<train_/valid_/test_>collate_fn`` (optional): specifies the collate function to use in the data
      loaders. The default one is typically fine, but some datasets might require a custom function.
    - ``shuffle`` (optional, default=True): specifies whether the data loaders should shuffle their
      samples or not.
    - ``test_seed`` (optional): specifies the RNG seed to use when splitting test data. If no seed is
      specified, the RNG will be initialized with a device-specific or time-related seed.
    - ``valid_seed`` (optional): specifies the RNG seed to use when splitting validation data. If no seed
      is specified, the RNG will be initialized with a device-specific or time-related seed.
    - ``torch_seed`` (optional): specifies the RNG seed to use for torch-related stochastic operations
      (e.g. for data augmentation). If no seed is specified, the RNG will be initialized with a
      device-specific or time-related seed.
    - ``numpy_seed`` (optional): specifies the RNG seed to use for numpy-related stochastic operations
      (e.g. for data augmentation). If no seed is specified, the RNG will be initialized with a
      device-specific or time-related seed.
    - ``random_seed`` (optional): specifies the RNG seed to use for stochastic operations with python's
      'random' package. If no seed is specified, the RNG will be initialized with a device-specific or
      time-related seed.
    - ``workers`` (optional, default=1): specifies the number of threads to use to preload batches in
      parallel; can be 0 (loading will be done in the main thread), or an integer >= 1.
    - ``pin_memory`` (optional, default=False): specifies whether the data loaders will copy tensors into
      CUDA-pinned memory before returning them.
    - ``drop_last`` (optional, default=False): specifies whether to drop the last incomplete batch or not
      if the dataset size is not a multiple of the batch size.
    - ``sampler`` (optional): specifies a type of sampler and its constructor parameters to be used in the
      data loaders. This can be used for example to help rebalance a dataset based on its class
      distribution. See :mod:`thelper.data.samplers` for more information.
    - ``augments`` (optional): provides a list of transformation operations used to augment all samples of
      a dataset. See :func:`thelper.transforms.utils.load_augments` for more info.
    - ``train_augments`` (optional): provides a list of transformation operations used to augment the
      training samples of a dataset. See :func:`thelper.transforms.utils.load_augments` for more info.
    - ``valid_augments`` (optional): provides a list of transformation operations used to augment the
      validation samples of a dataset. See :func:`thelper.transforms.utils.load_augments` for more info.
    - ``test_augments`` (optional): provides a list of transformation operations used to augment the test
      samples of a dataset. See :func:`thelper.transforms.utils.load_augments` for more info.
    - ``eval_augments`` (optional): provides a list of transformation operations used to augment the
      validation and test samples of a dataset. See :func:`thelper.transforms.utils.load_augments` for
      more info.
    - ``base_transforms`` (optional): provides a list of transformation operations to apply to all loaded
      samples. This list will be passed to the constructor of all instantiated dataset parsers. See
      :func:`thelper.transforms.utils.load_transforms` for more info.
    - ``train_split`` (optional): provides the proportion of samples of each dataset to hand off to the
      training data loader. These proportions are given in a dictionary format (``name: ratio``).
    - ``valid_split`` (optional): provides the proportion of samples of each dataset to hand off to the
      validation data loader. These proportions are given in a dictionary format (``name: ratio``).
    - ``test_split`` (optional): provides the proportion of samples of each dataset to hand off to the
      test data loader. These proportions are given in a dictionary format (``name: ratio``).
    - ``skip_verif`` (optional, default=True): specifies whether the dataset split should be verified when
      resuming a session by parsing the log files generated earlier.
    - ``skip_split_norm`` (optional, default=False): specifies whether the question about normalizing the
      split ratios should be skipped or not.
    - ``skip_class_balancing`` (optional, default=False): specifies whether the balancing of class labels
      should be skipped in case the task is classification-related.

    Example configuration file::

        # ...
        "loaders": {
            "batch_size": 128,  # batch size to use in data loaders
            "shuffle": true,  # specifies that the data should be shuffled
            "workers": 4,  # number of threads to pre-fetch data batches with
            "sampler": {  # we can use a data sampler to rebalance classes (optional)
                # see e.g. 'thelper.data.samplers.WeightedSubsetRandomSampler'
                # ...
            },
            "train_augments": {  # training data augmentation operations
                # see 'thelper.transforms.utils.load_augments'
                # ...
            },
            "eval_augments": {  # evaluation (valid/test) data augmentation operations
                # see 'thelper.transforms.utils.load_augments'
                # ...
            },
            "base_transforms": {  # global sample transformation operations
                # see 'thelper.transforms.utils.load_transforms'
                # ...
            },
            # finally, we define an 80%-10%-10% split for our data
            # (we could instead use one dataset for training and one for testing)
            "train_split": {
                "dataset_A": 0.8,
                "dataset_B": 0.8
            },
            "valid_split": {
                "dataset_A": 0.1,
                "dataset_B": 0.1
            },
            "test_split": {
                "dataset_A": 0.1,
                "dataset_B": 0.1
            }
            # (note that the dataset names above are defined in the field below)
        },
        "datasets": {
            "dataset_A": {
                # type of dataset interface to instantiate
                "type": "...",
                "params": {
                    # ...
                }
            },
            "dataset_B": {
                # type of dataset interface to instantiate
                "type": "...",
                "params": {
                    # ...
                },
                # if it does not derive from 'thelper.data.parsers.Dataset', a task is needed:
                "task": {
                    # this type must derive from 'thelper.tasks.Task'
                    "type": "...",
                    "params": {
                        # ...
                    }
                }
            },
            # ...
        },
        # ...

    Args:
        config: a dictionary that provides all required data configuration information under two fields,
            namely 'datasets' and 'loaders'.
        save_dir: the path to the root directory where the session directory should be saved. Note that
            this is not the path to the session directory itself, but its parent, which may also contain
            other session directories.

    Returns:
        A 4-element tuple that contains: 1) the global task object to specialize models and trainers with;
        2) the training data loader; 3) the validation data loader; and 4) the test data loader.

    .. seealso::
        | :func:`thelper.data.utils.create_parsers`
        | :func:`thelper.transforms.utils.load_augments`
        | :func:`thelper.transforms.utils.load_transforms`
    """
    logstamp = thelper.utils.get_log_stamp()
    repover = thelper.__version__ + ":" + thelper.utils.get_git_stamp()
    session_name = config["name"] if "name" in config else "session"
    data_logger_dir = None
    if save_dir is not None:
        data_logger_dir = os.path.join(save_dir, "logs")
        os.makedirs(data_logger_dir, exist_ok=True)
        data_logger_path = os.path.join(data_logger_dir, "data.log")
        data_logger_format = logging.Formatter("[%(asctime)s - %(process)s] %(levelname)s : %(message)s")
        data_logger_fh = logging.FileHandler(data_logger_path)
        data_logger_fh.setFormatter(data_logger_format)
        thelper.data.logger.addHandler(data_logger_fh)
        thelper.data.logger.info("created data log for session '%s'" % session_name)
    logger.debug("loading data usage config")
    # todo: 'data_config' field is deprecated, might be removed later
    if "data_config" in config:
        logger.warning("using 'data_config' field in configuration dictionary is deprecated; switch it to 'loaders'")
    loaders_config = thelper.utils.get_key(["data_config", "loaders"], config)
    # noinspection PyProtectedMember
    from thelper.data.loaders import LoaderFactory as LoaderFactory
    loader_factory = LoaderFactory(loaders_config)
    datasets, task = create_parsers(config, loader_factory.get_base_transforms())
    if not datasets or task is None:
        raise AssertionError("invalid dataset configuration (got empty list)")
    for dataset_name, dataset in datasets.items():
        logger.info("parsed dataset: %s" % str(dataset))
    logger.info("task info: %s" % str(task))
    logger.debug("splitting datasets and creating loaders...")
    train_idxs, valid_idxs, test_idxs = loader_factory.get_split(datasets, task)
    if save_dir is not None:
        with open(os.path.join(data_logger_dir, "task.log"), "a+") as fd:
            fd.write("session: %s-%s\n" % (session_name, logstamp))
            fd.write("version: %s\n" % repover)
            fd.write(str(task) + "\n")
        for dataset_name, dataset in datasets.items():
            dataset_log_file = os.path.join(data_logger_dir, dataset_name + ".log")
            if not loader_factory.skip_verif and os.path.isfile(dataset_log_file):
                logger.info("verifying sample list for dataset '%s'..." % dataset_name)
                with open(dataset_log_file, "r") as fd:
                    log_content = fd.read()
                if not log_content or log_content[0] != "{":
                    # could not find a new-style (json) dataset log; cannot easily parse and compare this log
                    logger.warning("cannot verify that old split is similar to new split, log is out-of-date")
                    continue
                log_content = json.loads(log_content)
                if "samples" not in log_content or not isinstance(log_content["samples"], list):
                    raise AssertionError("unexpected dataset log content (bad 'samples' field)")
                samples_old = log_content["samples"]
                samples_new = dataset.samples if hasattr(dataset, "samples") and isinstance(dataset.samples, list) else []
                if len(samples_old) != len(samples_new):
                    query_msg = "Old sample list for dataset '%s' does not match the current sample list; " \
                                "proceed anyway?" % dataset_name
                    answer = thelper.utils.query_yes_no(query_msg, bypass="n")
                    if not answer:
                        logger.error("sample list mismatch with previous run; user aborted")
                        sys.exit(1)
                    break
                else:
                    breaking = False
                    for set_name, idxs in zip(["train_idxs", "valid_idxs", "test_idxs"],
                                              [train_idxs[dataset_name], valid_idxs[dataset_name],
                                               test_idxs[dataset_name]]):
                        # index values were paired in tuples earlier, 0=idx, 1=label
                        if log_content[set_name] != [idx for idx, _ in idxs]:
                            query_msg = "Old indices list for dataset '%s' does not match the current indices " \
                                        "list ('%s'); proceed anyway?" % (dataset_name, set_name)
                            answer = thelper.utils.query_yes_no(query_msg, bypass="n")
                            if not answer:
                                logger.error("indices list mismatch with previous run; user aborted")
                                sys.exit(1)
                            breaking = True
                            break
                    if not breaking:
                        for idx, (sample_new, sample_old) in enumerate(zip(samples_new, samples_old)):
                            if str(sample_new) != sample_old:
                                query_msg = "Old sample #%d for dataset '%s' does not match current #%d; " \
                                            "proceed anyway?\n\told: %s\n\tnew: %s" \
                                            % (idx, dataset_name, idx, str(sample_old), str(sample_new))
                                answer = thelper.utils.query_yes_no(query_msg, bypass="n")
                                if not answer:
                                    logger.error("sample list mismatch with previous run; user aborted")
                                    sys.exit(1)
                                break
        for dataset_name, dataset in datasets.items():
            dataset_log_file = os.path.join(data_logger_dir, dataset_name + ".log")
            samples = dataset.samples if hasattr(dataset, "samples") and isinstance(dataset.samples, list) else []
            log_content = {
                "metadata": {
                    "session_name": session_name,
                    "logstamp": logstamp,
                    "version": repover,
                    "dataset": str(dataset),
                },
                "samples": [str(sample) for sample in samples],
                # index values were paired in tuples earlier, 0=idx, 1=label
                "train_idxs": [idx for idx, _ in train_idxs[dataset_name]],
                "valid_idxs": [idx for idx, _ in valid_idxs[dataset_name]],
                "test_idxs": [idx for idx, _ in test_idxs[dataset_name]]
            }
            # always overwrite the dataset log, as it can get too big otherwise
            with open(dataset_log_file, "w") as fd:
                json.dump(log_content, fd, indent=4, sort_keys=False)
    train_loader, valid_loader, test_loader = loader_factory.create_loaders(datasets, train_idxs, valid_idxs, test_idxs)
    return task, train_loader, valid_loader, test_loader
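

# Example (illustrative only): a minimal configuration handed to ``create_loaders``. The dataset type,
# its parameters, and the save directory below are placeholders/assumptions, not values shipped with
# thelper; the split sends 80%/10%/10% of 'dataset_A' to the train/valid/test loaders, as documented in
# the docstring above. This sketch is never called by the module itself.
def _example_create_loaders():
    example_config = {
        "name": "demo_session",
        "loaders": {
            "batch_size": 32,
            "shuffle": True,
            "workers": 0,
            "train_split": {"dataset_A": 0.8},
            "valid_split": {"dataset_A": 0.1},
            "test_split": {"dataset_A": 0.1},
        },
        "datasets": {
            "dataset_A": {
                "type": "my_package.MyDataset",  # hypothetical parser deriving from thelper.data.Dataset
                "params": {"root": "/path/to/data"},  # hypothetical constructor arguments
            },
        },
    }
    task, train_loader, valid_loader, test_loader = create_loaders(example_config, save_dir="sessions")
    return task, train_loader, valid_loader, test_loader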


def create_parsers(config, base_transforms=None):
    """Instantiates dataset parsers based on a provided dictionary.

    This function will instantiate dataset parsers as defined in a name-type-param dictionary. If multiple
    datasets are instantiated, this function will also verify their task compatibility and return the
    global task. The dataset interfaces themselves should be derived from
    :class:`thelper.data.parsers.Dataset`, be compatible with :class:`thelper.data.parsers.ExternalDataset`,
    or provide a 'task' field specifying all the information related to sample dictionary keys and
    model i/o.

    The provided configuration will be parsed for a 'datasets' dictionary entry. The keys in this dictionary
    are treated as unique dataset names and are used for lookups. The value associated with each key (or
    dataset name) should be a type-params dictionary that can be parsed to instantiate the dataset interface.

    An example configuration dictionary is given in :func:`thelper.data.utils.create_loaders`.

    Args:
        config: a dictionary that provides unique dataset names and parameters needed for instantiation
            under the 'datasets' field.
        base_transforms: the transform operation that should be applied to all loaded samples, and that
            will be provided to the constructor of all instantiated dataset parsers.

    Returns:
        A 2-element tuple that contains: 1) the map of dataset interfaces/parsers that were instantiated
        (indexed by dataset name); and 2) a task object compatible with all of those (see
        :class:`thelper.tasks.utils.Task` for more information).

    .. seealso::
        | :func:`thelper.data.utils.create_loaders`
        | :class:`thelper.data.parsers.Dataset`
        | :class:`thelper.data.parsers.ExternalDataset`
        | :class:`thelper.tasks.utils.Task`
    """
    if not isinstance(config, dict):
        raise AssertionError("unexpected session config type")
    if "datasets" not in config or not config["datasets"]:
        raise AssertionError("config missing 'datasets' field (must contain dict or str value)")
    config = config["datasets"]  # no need to keep the full config here
    if isinstance(config, str):
        if os.path.isfile(config) and os.path.splitext(config)[1] == ".json":
            with open(config) as fd:
                config = json.load(fd)
        else:
            raise AssertionError("'datasets' string should point to valid json file")
    logger.debug("loading datasets templates")
    if not isinstance(config, dict):
        raise AssertionError("invalid datasets config type (must be dictionary)")
    datasets = {}
    tasks = []
    for dataset_name, dataset_config in config.items():
        if isinstance(dataset_config, thelper.data.Dataset):
            dataset = dataset_config
            task = dataset.task
        else:
            logger.debug("loading dataset '%s' configuration..." % dataset_name)
            if "type" not in dataset_config:
                raise AssertionError("missing field 'type' for instantiation of dataset '%s'" % dataset_name)
            dataset_type = thelper.utils.import_class(dataset_config["type"])
            dataset_params = thelper.utils.get_key_def(["params", "parameters"], dataset_config, {})
            transforms = None
            if "transforms" in dataset_config and dataset_config["transforms"]:
                logger.debug("loading custom transforms for dataset '%s'..." % dataset_name)
                transforms = thelper.transforms.load_transforms(dataset_config["transforms"])
                if base_transforms is not None:
                    transforms = thelper.transforms.Compose([transforms, base_transforms])
            elif base_transforms is not None:
                transforms = base_transforms
            if issubclass(dataset_type, thelper.data.Dataset):
                # assume that the dataset is derived from thelper.data.parsers.Dataset (it is fully sampling-ready)
                dataset_sig = inspect.signature(dataset_type)
                if "config" in dataset_sig.parameters:  # pragma: no cover
                    # @@@@ for backward compatibility only, will be removed in v0.3
                    dataset = dataset_type(transforms=transforms, config=dataset_params)
                else:
                    dataset = dataset_type(transforms=transforms, **dataset_params)
                if "task" in dataset_config:
                    logger.warning("'task' field detected in dataset '%s' config; dataset's default task will be ignored"
                                   % dataset_name)
                    task = thelper.tasks.create_task(dataset_config["task"])
                else:
                    task = dataset.task
            else:
                if "task" not in dataset_config or not dataset_config["task"]:
                    raise AssertionError("external dataset '%s' must define task interface in its configuration dict"
                                         % dataset_name)
                task = thelper.tasks.create_task(dataset_config["task"])
                # assume that __getitem__ and __len__ are implemented, but we need to make it sampling-ready
                dataset = thelper.data.ExternalDataset(dataset_type, task, transforms=transforms, **dataset_params)
        if task is None:
            raise AssertionError("parsed task interface should not be None anymore (old code doing something strange?)")
        tasks.append(task)
        datasets[dataset_name] = dataset
    return datasets, thelper.tasks.create_global_task(tasks)
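

# Example (illustrative only): calling ``create_parsers`` directly instead of going through
# ``create_loaders``. The dataset type and parameters are placeholders; any class deriving from
# thelper.data.Dataset (or an external dataset with an explicit 'task' field) would work here.
def _example_create_parsers():
    config = {
        "datasets": {
            "dataset_A": {
                "type": "my_package.MyDataset",  # hypothetical parser class
                "params": {"root": "/path/to/data"},  # hypothetical constructor arguments
            },
        },
    }
    datasets, task = create_parsers(config)
    # 'datasets' maps each dataset name to its instantiated parser; 'task' is the global task object
    return datasets, task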


def create_hdf5(archive_path, task, train_loader, valid_loader, test_loader, compression=None, config_backup=None):
    """Saves the samples loaded from train/valid/test data loaders into an HDF5 archive.

    The loaded minibatches are decomposed into individual samples. The keys provided via the task interface
    are used to fetch elements (input, groundtruth, ...) from the samples, and save them in the archive. The
    archive will contain three groups (`train`, `valid`, and `test`), and each group will contain a dataset
    for each element originally found in the samples.

    Note that the compression operates at the sample level, not at the dataset level. This means that
    elements of each sample will be compressed individually, not as an array. Therefore, if you are trying
    to compress highly correlated samples (e.g. frames in a video sequence), this approach will be quite
    inefficient.

    Args:
        archive_path: path to the location where the HDF5 archive should be created.
        task: task object that defines the input, groundtruth, and meta keys tied to elements that should
            be parsed from loaded samples and saved in the HDF5 archive.
        train_loader: training data loader (can be `None`).
        valid_loader: validation data loader (can be `None`).
        test_loader: testing data loader (can be `None`).
        compression: the compression configuration dictionary that will be parsed to determine how sample
            elements should be compressed. If a mapping is missing, that element will not be compressed.
        config_backup: optional session configuration file that should be saved in the HDF5 archive.

    Example compression configuration::

        # the config is given as a dictionary
        {
            # each field is a key that corresponds to an element in each sample
            "key1": {
                # the 'type' identifies the compression approach to use
                # (see thelper.utils.encode_data for more information)
                "type": "jpg",
                # extra parameters might be needed to encode the data
                # (see thelper.utils.encode_data for more information)
                "encode_params": {},
                # these parameters are packed and kept for decoding
                # (see thelper.utils.decode_data for more information)
                "decode_params": {"flags": "cv.IMREAD_COLOR"}
            },
            "key2": {
                # this explicitly means that no encoding should be performed
                "type": "none"
            },
            ...  # if a key is missing, its elements will not be compressed
        }

    .. seealso::
        | :func:`thelper.cli.split_data`
        | :class:`thelper.data.parsers.HDF5Dataset`
        | :func:`thelper.utils.encode_data`
        | :func:`thelper.utils.decode_data`
    """
    if compression is None:
        compression = {}
    if config_backup is None:
        config_backup = {}
    with h5py.File(archive_path, "w") as fd:
        fd.attrs["source"] = thelper.utils.get_log_stamp()
        fd.attrs["git_sha1"] = thelper.utils.get_git_stamp()
        fd.attrs["version"] = thelper.__version__
        fd.attrs["task"] = str(task)
        fd.attrs["config"] = str(config_backup)
        fd.attrs["compression"] = str(compression)
        dtype = h5py.special_dtype(vlen=np.uint8)
        target_keys = task.keys

        def create_dataset(name, max_len, array_template, compr_args):
            if array_template.ndim > 1:
                dset = fd.create_dataset(name, shape=(max_len,), maxshape=(max_len,), dtype=dtype)
                dset.attrs["orig_dtype"] = str(array_template.dtype)
                dset.attrs["orig_shape"] = array_template.shape[1:]  # removes batch dim
            else:
                assert thelper.utils.is_scalar(array_template[0])
                # compr_args holds the (type, encode_params) pair; scalar elements cannot be compressed
                if compr_args[0] != "none":
                    raise AssertionError("cannot compress scalar elements")
                if np.issubdtype(array_template.dtype, np.number):
                    dset = fd.create_dataset(name, shape=(max_len,), maxshape=(max_len,), dtype=array_template.dtype)
                else:
                    dset = fd.create_dataset(name, shape=(max_len,), maxshape=(max_len,), dtype=dtype)
                dset.attrs["orig_dtype"] = str(array_template.dtype)
                dset.attrs["orig_shape"] = ()
            return dset

        def fill_dataset(dset, dset_idx, array_idx, array, compr, **compr_kwargs):
            if array.ndim > 1:
                dset[dset_idx] = np.frombuffer(thelper.utils.encode_data(array[array_idx], compr, **compr_kwargs),
                                               dtype=np.uint8)
            elif not np.issubdtype(array.dtype, np.number):
                sample = array[array_idx].tobytes() \
                    if np.issubdtype(array[array_idx].dtype, np.dtype(str).type) else array[array_idx]
                dset[dset_idx] = np.frombuffer(sample, dtype=np.uint8)
            else:
                dset[dset_idx] = array[array_idx]

        def get_compr_args(key, config):
            config = thelper.utils.get_key_def(key, config, default={})
            compr_type = thelper.utils.get_key_def("type", config, default="none")
            encode_params = thelper.utils.get_key_def("encode_params", config, default={})
            return compr_type, encode_params

        for loader, group in [(train_loader, "train"), (valid_loader, "valid"), (test_loader, "test")]:
            if loader is None:
                continue
            max_dataset_len = len(loader) * loader.batch_size
            datasets = {key: None for key in target_keys}
            datasets_len = {key: 0 for key in target_keys}
            datasets_compr = {key: get_compr_args(key, compression) for key in target_keys}
            for batch in tqdm.tqdm(loader, desc=f"packing {group} loader"):
                for key in target_keys:
                    tensor = thelper.utils.to_numpy(batch[key])
                    if datasets[key] is None:
                        datasets[key] = create_dataset(group + "/" + key, max_dataset_len, tensor, datasets_compr[key])
                    for idx in range(tensor.shape[0]):
                        fill_dataset(datasets[key], datasets_len[key], idx, tensor,
                                     datasets_compr[key][0], **datasets_compr[key][1])
                        datasets_len[key] += 1
            assert len(set(datasets_len.values())) == 1
            fd[group].attrs["count"] = datasets_len[task.input_key]
            for key in target_keys:
                datasets[key].resize(size=(datasets_len[key],))
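

# Example (illustrative only): packing previously created loaders into an HDF5 archive. The output path,
# the 'image' key, and the JPEG compression settings are assumptions for this sketch (the key must match
# one of the task's keys); in practice this is typically driven through thelper.cli.split_data.
def _example_create_hdf5(task, train_loader, valid_loader, test_loader):
    compression = {
        "image": {"type": "jpg", "encode_params": {}},  # hypothetical key; other keys stay uncompressed
    }
    create_hdf5("packed_dataset.hdf5", task, train_loader, valid_loader, test_loader, compression=compression)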


def get_class_weights(label_map, stype, invmax, maxw=float('inf'), minw=0.0, norm=True):
    """Returns a map of adjusted class weights based on a given rebalancing strategy.

    Args:
        label_map: map of index lists or sample counts tied to class labels.
        stype: weighting strategy ('uniform', 'linear', or 'rootX'); see
            :class:`thelper.data.samplers.WeightedSubsetRandomSampler` for more information on these.
        invmax: specifies whether to max-invert the weight vector (thus creating cost factors) or not.
        maxw: maximum allowed weight value (applied after invmax, if required).
        minw: minimum allowed weight value (applied after invmax, if required).
        norm: specifies whether the returned weights should be normalized (default=True, i.e. normalized).

    Returns:
        Map of adjusted weights tied to class labels.

    .. seealso::
        | :class:`thelper.data.samplers.WeightedSubsetRandomSampler`
    """
    if not isinstance(label_map, dict) or any([not isinstance(val, (list, int)) for val in label_map.values()]):
        raise AssertionError("unexpected label map type")
    if stype == "uniform":
        label_weights = {label: 1.0 / len(label_map) for label in label_map}
    elif stype == "linear" or "root" in stype:
        if stype == "root" or stype == "linear":
            rpow = 1.0
        else:
            rpow = 1.0 / float(stype.split("root", 1)[1])
        label_sizes = {label: len(v) if isinstance(v, list) else v for label, v in label_map.items()}
        label_weights = {label: (lsize / sum(label_sizes.values())) ** rpow for label, lsize in label_sizes.items()}
    else:
        raise AssertionError("unknown label weighting strategy")
    if invmax:
        label_weights = {label: max(label_weights.values()) / max(weight, 1e-6)
                         for label, weight in label_weights.items()}
    label_weights = {label: min(max(weight, minw), maxw) for label, weight in label_weights.items()}
    if norm:
        tot_weight = sum([w for w in label_weights.values()])
        label_weights = {label: weight / tot_weight for label, weight in label_weights.items()}
    return label_weights
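

# Example (illustrative only): a small worked case for ``get_class_weights``. With 30 'cat' and 10 'dog'
# samples and the 'linear' strategy, the raw weights are 0.75/0.25; max-inversion turns them into cost
# factors of 1.0/3.0, and normalization then rescales those to 0.25/0.75.
def _example_get_class_weights():
    weights = get_class_weights({"cat": 30, "dog": 10}, "linear", invmax=True)
    # expected result (up to floating-point rounding): {"cat": 0.25, "dog": 0.75}
    return weights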