"""General utilities module.
This module contains only non-ML-specific functions, i/o helpers,
and matplotlib/pyplot drawing calls.
"""
import copy
import errno
import functools
import glob
import importlib
import importlib.util
import inspect
import io
import itertools
import json
import logging
import math
import os
import pathlib
import pickle
import platform
import re
import sys
import time
from typing import TYPE_CHECKING
import cv2 as cv
import lz4.frame
import matplotlib.pyplot as plt
import numpy as np
import sklearn.metrics
import torch
import thelper.typedefs # noqa: F401
if TYPE_CHECKING:
from typing import List, Optional, Type # noqa: F401
from types import FunctionType # noqa: F401
logger = logging.getLogger(__name__)
bypass_queries = False
class Struct:
"""Generic runtime-defined C-like data structure (maps constructor elements to fields)."""
    def __init__(self, **kwargs):
for key, val in kwargs.items():
setattr(self, key, val)
def __repr__(self):
return self.__class__.__module__ + "." + self.__class__.__qualname__ + \
"(" + ", ".join([f"{key}={repr(val)}" for key, val in self.__dict__.items()]) + ")"
def get_available_cuda_devices(attempts_per_device=5):
# type: (Optional[int]) -> List[int]
"""
    Tests all visible cuda devices and returns a list of available ones.

    Args:
        attempts_per_device: number of tests to run per device before declaring it unavailable.

    Returns:
        List of available cuda device IDs (integers). An empty list means no
        cuda device is available, and the app should fall back to cpu.
"""
if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
return []
devices_available = [False] * torch.cuda.device_count()
attempt_broadcast = False
for attempt in range(attempts_per_device):
for device_id in range(torch.cuda.device_count()):
if not devices_available[device_id]:
if not attempt_broadcast:
logger.debug("testing availability of cuda device #%d (%s)" % (
device_id, torch.cuda.get_device_name(device_id)
))
# noinspection PyBroadException
try:
torch.cuda.set_device(device_id)
test_val = torch.cuda.FloatTensor([1])
if test_val.cpu().item() != 1.0:
raise AssertionError("sometime's really wrong")
devices_available[device_id] = True
except Exception:
pass
attempt_broadcast = True
return [device_id for device_id, available in enumerate(devices_available) if available]
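# A minimal usage sketch: pick the first available device, falling back to cpu.
#
#   devices = get_available_cuda_devices()
#   device = torch.device(f"cuda:{devices[0]}" if devices else "cpu")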
def setup_plt(config):
"""Parses the provided config for matplotlib flags and sets up its global state accordingly."""
config = get_key_def(["plt", "pyplot", "matplotlib"], config, {})
if "backend" in config:
import matplotlib
matplotlib.use(get_key("backend", config))
plt.interactive(get_key_def("interactive", config, False))
# noinspection PyUnusedLocal
def setup_cv2(config):
"""Parses the provided config for OpenCV flags and sets up its global state accordingly."""
# https://github.com/pytorch/pytorch/issues/1355
cv.setNumThreads(0)
cv.ocl.setUseOpenCL(False)
# todo: add more global opencv flags setups here
def setup_cudnn(config):
"""Parses the provided config for CUDNN flags and sets up PyTorch accordingly."""
if "cudnn" in config and isinstance(config["cudnn"], dict):
config = config["cudnn"]
if "benchmark" in config:
cudnn_benchmark_flag = str2bool(config["benchmark"])
logger.debug("cudnn benchmark mode = %s" % str(cudnn_benchmark_flag))
torch.backends.cudnn.benchmark = cudnn_benchmark_flag
if "deterministic" in config:
cudnn_deterministic_flag = str2bool(config["deterministic"])
logger.debug("cudnn deterministic mode = %s" % str(cudnn_deterministic_flag))
torch.backends.cudnn.deterministic = cudnn_deterministic_flag
else:
if "cudnn_benchmark" in config:
cudnn_benchmark_flag = str2bool(config["cudnn_benchmark"])
logger.debug("cudnn benchmark mode = %s" % str(cudnn_benchmark_flag))
torch.backends.cudnn.benchmark = cudnn_benchmark_flag
if "cudnn_deterministic" in config:
cudnn_deterministic_flag = str2bool(config["cudnn_deterministic"])
logger.debug("cudnn deterministic mode = %s" % str(cudnn_deterministic_flag))
torch.backends.cudnn.deterministic = cudnn_deterministic_flag
def setup_globals(config):
"""Parses the provided config for global flags and sets up the global state accordingly."""
if "bypass_queries" in config and config["bypass_queries"]:
global bypass_queries
bypass_queries = True
setup_plt(config)
setup_cv2(config)
setup_cudnn(config)
def load_checkpoint(ckpt,            # type: thelper.typedefs.CheckpointLoadingType
map_location=None, # type: Optional[thelper.typedefs.MapLocationType]
always_load_latest=False, # type: Optional[bool]
check_version=True, # type: Optional[bool]
): # type: (...) -> thelper.typedefs.CheckpointContentType
"""Loads a session checkpoint via PyTorch, check its compatibility, and returns its data.
If the ``ckpt`` parameter is a path to a valid directory, then that directly will be searched for
a checkpoint. If multiple checkpoints are found, the latest will be returned (based on the epoch
index in its name). iF ``always_load_latest`` is set to False and if a checkpoint named
``ckpt.best.pth`` is found, it will be returned instead.
Args:
ckpt: a file-like object or a path to the checkpoint file or session directory.
map_location: a function, string or a dict specifying how to remap storage
locations. See ``torch.load`` for more information.
always_load_latest: toggles whether to always try to load the latest checkpoint
if a session directory is provided (instead of loading the 'best' checkpoint).
        check_version: toggles whether the checkpoint's version should be checked for
            compatibility issues, querying the user on how to proceed if any are found.
Returns:
Content of the checkpoint (a dictionary).
"""
if map_location is None and not get_available_cuda_devices():
map_location = 'cpu'
if isinstance(ckpt, str) and os.path.isdir(ckpt):
logger.debug("will search directory '%s' for a checkpoint to load..." % ckpt)
search_ckpt_dir = os.path.join(ckpt, "checkpoints")
if os.path.isdir(search_ckpt_dir):
search_dir = search_ckpt_dir
else:
search_dir = ckpt
ckpt_paths = glob.glob(os.path.join(search_dir, "ckpt.*.pth"))
if not ckpt_paths:
raise AssertionError("could not find any valid checkpoint files in directory '%s'" % search_dir)
latest_epoch, latest_day, latest_time = -1, -1, -1
for ckpt_path in ckpt_paths:
# note: the 2nd field in the name should be the epoch index, or 'best' if final checkpoint
split = os.path.basename(ckpt_path).split(".")
tag = split[1]
if tag == "best" and (not always_load_latest or latest_epoch == -1):
# if eval-only, always pick the best checkpoint; otherwise, only pick if nothing else exists
ckpt = ckpt_path
if not always_load_latest:
break
elif tag != "best":
log_stamp = split[2] if len(split) > 2 else ""
log_stamp = "fake-0-0" if log_stamp.count("-") != 2 else log_stamp
epoch_stamp, day_stamp, time_stamp = int(tag), int(log_stamp.split("-")[1]), int(log_stamp.split("-")[2])
if epoch_stamp > latest_epoch or day_stamp > latest_day or time_stamp > latest_time:
ckpt, latest_epoch, latest_day, latest_time = ckpt_path, epoch_stamp, day_stamp, time_stamp
if not os.path.isfile(ckpt):
raise AssertionError("could not find valid checkpoint at '%s'" % ckpt)
basepath = None
if isinstance(ckpt, str):
logger.debug("parsing checkpoint at '%s'" % ckpt)
basepath = os.path.dirname(os.path.abspath(ckpt))
else:
if hasattr(ckpt, "name"):
logger.debug("parsing checkpoint provided via file object")
basepath = os.path.dirname(os.path.abspath(ckpt.name))
ckptdata = torch.load(ckpt, map_location=map_location)
if not isinstance(ckptdata, dict):
raise AssertionError("unexpected checkpoint data type")
if check_version:
good_version = False
from thelper import __version__ as curr_ver
if "version" not in ckptdata:
logger.warning("checkpoint missing internal version tag")
ckpt_ver_str = "0.0.0"
else:
ckpt_ver_str = ckptdata["version"]
if not isinstance(ckpt_ver_str, str) or len(ckpt_ver_str.split(".")) != 3:
raise AssertionError("unexpected checkpoint version formatting")
        # by default, checkpoints should be from the same minor version; we warn otherwise
versions = [curr_ver.split("."), ckpt_ver_str.split(".")]
if versions[0][0] != versions[1][0]:
logger.error("incompatible checkpoint, major version mismatch (%s vs %s)" % (curr_ver, ckpt_ver_str))
elif versions[0][1] != versions[1][1]:
logger.warning("outdated checkpoint, minor version mismatch (%s vs %s)" % (curr_ver, ckpt_ver_str))
else:
good_version = True
if not good_version:
answer = query_string("Checkpoint version unsupported (framework=%s, checkpoint=%s); how do you want to proceed?" %
(curr_ver, ckpt_ver_str), choices=["continue", "migrate", "abort"], default="migrate", bypass="migrate")
if answer == "abort":
logger.error("checkpoint out-of-date; user aborted")
sys.exit(1)
elif answer == "continue":
logger.warning("will attempt to load checkpoint anyway (might crash later due to incompatibilities)")
elif answer == "migrate":
ckptdata = migrate_checkpoint(ckptdata)
# load model trace if needed (we do it here since we can locate the neighboring file)
if "model" in ckptdata and isinstance(ckptdata["model"], str):
trace_path = None
if os.path.isfile(ckptdata["model"]):
trace_path = ckptdata["model"]
elif basepath is not None and os.path.isfile(os.path.join(basepath, ckptdata["model"])):
trace_path = os.path.join(basepath, ckptdata["model"])
if trace_path is not None:
if trace_path.endswith(".pth"):
ckptdata["model"] = torch.load(trace_path, map_location=map_location)
elif trace_path.endswith(".zip"):
ckptdata["model"] = torch.jit.load(trace_path, map_location=map_location)
return ckptdata
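# A usage sketch (the session directory below is hypothetical; the 'model' field may hold
# a state dict, a loaded trace, or an exported module depending on how the session saved it):
#
#   ckptdata = load_checkpoint("sessions/my-session", map_location="cpu")
#   model_data = ckptdata["model"]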
def migrate_checkpoint(ckptdata,  # type: thelper.typedefs.CheckpointContentType
): # type: (...) -> thelper.typedefs.CheckpointContentType
"""Migrates the content of an incompatible or outdated checkpoint to the current version of the framework.
This function might not be able to fix all backward compatibility issues (e.g. it cannot fix class interfaces
that were changed). Perfect reproductibility of tests cannot be guaranteed either if this migration tool is used.
Args:
ckptdata: checkpoint data in dictionary form obtained via ``thelper.utils.load_checkpoint``. Note that
the data contained in this dictionary will be modified in-place.
Returns:
An updated checkpoint dictionary that should be compatible with the current version of the framework.
"""
if not isinstance(ckptdata, dict):
raise AssertionError("unexpected ckptdata type")
from thelper import __version__ as curr_ver
curr_ver = [int(num) for num in curr_ver.split(".")]
ckpt_ver_str = ckptdata["version"] if "version" in ckptdata else "0.0.0"
ckpt_ver = [int(num) for num in ckpt_ver_str.split(".")]
if (ckpt_ver[0] > curr_ver[0] or (ckpt_ver[0] == curr_ver[0] and ckpt_ver[1] > curr_ver[1]) or
(ckpt_ver[0:2] == curr_ver[0:2] and ckpt_ver[2] > curr_ver[2])):
raise AssertionError("cannot migrate checkpoints from future versions!")
if "config" not in ckptdata:
raise AssertionError("checkpoint migration requires config")
old_config = ckptdata["config"]
new_config = migrate_config(copy.deepcopy(old_config), ckpt_ver_str)
if ckpt_ver == [0, 0, 0]:
logger.warning("trying to migrate checkpoint data from v0.0.0; all bets are off")
else:
logger.info("trying to migrate checkpoint data from v%s" % ckpt_ver_str)
if ckpt_ver[0] <= 0 and ckpt_ver[1] <= 1:
# combine 'host' and 'time' fields into 'source'
if "host" in ckptdata and "time" in ckptdata:
ckptdata["source"] = ckptdata["host"] + ckptdata["time"]
del ckptdata["host"]
del ckptdata["time"]
# update classif task interface
if "task" in ckptdata and isinstance(ckptdata["task"], thelper.tasks.classif.Classification):
ckptdata["task"] = str(thelper.tasks.classif.Classification(class_names=ckptdata["task"].class_names,
input_key=ckptdata["task"].input_key,
label_key=ckptdata["task"].label_key,
meta_keys=ckptdata["task"].meta_keys))
# move 'state_dict' field to 'model'
if "state_dict" in ckptdata:
ckptdata["model"] = ckptdata["state_dict"]
del ckptdata["state_dict"]
# create 'model_type' and 'model_params' fields
if "model" in new_config:
if "type" in new_config["model"]:
ckptdata["model_type"] = new_config["model"]["type"]
else:
ckptdata["model_type"] = None
if "params" in new_config["model"]:
ckptdata["model_params"] = copy.deepcopy(new_config["model"]["params"])
else:
ckptdata["model_params"] = {}
# TODO: create 'scheduler' field to restore previous state? (not so important for early versions)
# ckpt_ver = [0, 2, 0] # set ver for next update step
# if ckpt_ver[0] <= x and ckpt_ver[1] <= y and ckpt_ver[2] <= z:
# ... add more compatibility fixes here
ckptdata["config"] = new_config
return ckptdata
def migrate_config(config,       # type: thelper.typedefs.ConfigDict
cfg_ver_str, # type: str
): # type: (...) -> thelper.typedefs.ConfigDict
"""Migrates the content of an incompatible or outdated configuration to the current version of the framework.
This function might not be able to fix all backward compatibility issues (e.g. it cannot fix class interfaces
that were changed). Perfect reproductibility of tests cannot be guaranteed either if this migration tool is used.
Args:
config: session configuration dictionary obtained e.g. by parsing a JSON file. Note that the data contained
in this dictionary will be modified in-place.
cfg_ver_str: string representing the version for which the configuration was created (e.g. "0.2.0").
Returns:
An updated configuration dictionary that should be compatible with the current version of the framework.
"""
if not isinstance(config, dict):
raise AssertionError("unexpected config type")
if not isinstance(cfg_ver_str, str) or len(cfg_ver_str.split(".")) != 3:
raise AssertionError("unexpected checkpoint version formatting")
from thelper import __version__ as curr_ver
curr_ver = [int(num) for num in curr_ver.split(".")]
cfg_ver = [int(num) for num in cfg_ver_str.split(".")]
if (cfg_ver[0] > curr_ver[0] or (cfg_ver[0] == curr_ver[0] and cfg_ver[1] > curr_ver[1]) or
(cfg_ver[0:2] == curr_ver[0:2] and cfg_ver[2] > curr_ver[2])):
raise AssertionError("cannot migrate configs from future versions!")
if cfg_ver == [0, 0, 0]:
logger.warning("trying to migrate config from v0.0.0; all bets are off")
else:
logger.info("trying to migrate config from v%s" % cfg_ver_str)
if cfg_ver[0] <= 0 and cfg_ver[1] < 1:
# must search for name-value parameter lists and convert them to dictionaries
def name_value_replacer(cfg):
if isinstance(cfg, dict):
for key, val in cfg.items():
if (key == "params" or key == "parameters") and isinstance(val, list) and \
all([isinstance(p, dict) and list(p.keys()) == ["name", "value"] for p in val]):
cfg["params"] = {param["name"]: name_value_replacer(param["value"]) for param in val}
if key == "parameters":
del cfg["parameters"]
elif isinstance(val, (dict, list)):
cfg[key] = name_value_replacer(val)
elif isinstance(cfg, list):
for idx, val in enumerate(cfg):
cfg[idx] = name_value_replacer(val)
return cfg
config = name_value_replacer(config)
# must replace "data_config" section by "loaders"
if "data_config" in config:
config["loaders"] = config["data_config"]
del config["data_config"]
# remove deprecated name attribute for models
if "model" in config and isinstance(config["model"], dict) and "name" in config["model"]:
del config["model"]["name"]
# must update import targets wrt class name refactorings
def import_refactoring(cfg): # noqa: E306
if isinstance(cfg, dict):
for key, val in cfg.items():
cfg[key] = import_refactoring(val)
elif isinstance(cfg, list):
for idx, val in enumerate(cfg):
cfg[idx] = import_refactoring(val)
elif isinstance(cfg, str) and cfg.startswith("thelper."):
cfg = thelper.utils.resolve_import(cfg)
return cfg
config = import_refactoring(config)
if "trainer" in config and isinstance(config["trainer"], dict):
trainer_cfg = config["trainer"]
# move 'loss' section to 'optimization' section
if "loss" in trainer_cfg:
if "optimization" not in trainer_cfg or not isinstance(trainer_cfg["optimization"], dict):
trainer_cfg["optimization"] = {}
trainer_cfg["optimization"]["loss"] = trainer_cfg["loss"]
del trainer_cfg["loss"]
# replace all devices with cuda:all
if "train_device" in trainer_cfg:
del trainer_cfg["train_device"]
if "valid_device" in trainer_cfg:
del trainer_cfg["valid_device"]
if "test_device" in trainer_cfg:
del trainer_cfg["test_device"]
if "device" not in trainer_cfg:
trainer_cfg["device"] = "cuda:all"
# remove params from trainer config
if "params" in trainer_cfg:
if not isinstance(trainer_cfg["params"], (dict, list)) or trainer_cfg["params"]:
logger.warning("removing non-empty parameter section from trainer config")
del trainer_cfg["params"]
cfg_ver = [0, 1, 0] # set ver for next update step
if cfg_ver[0] <= 0 and cfg_ver[1] <= 1:
# remove 'force_convert' flags from all transform pipelines + build augment pipeline wrappers
def remove_force_convert(cfg): # noqa: E306
if isinstance(cfg, list):
for idx, stage in enumerate(cfg):
cfg[idx] = remove_force_convert(stage)
elif isinstance(cfg, dict):
if "parameters" in cfg:
cfg["params"] = cfg["parameters"]
del cfg["parameters"]
if "operation" in cfg and cfg["operation"] == "thelper.transforms.TransformWrapper":
if "params" in cfg and "force_convert" in cfg["params"]:
del cfg["params"]["force_convert"]
for key, stage in cfg.items():
cfg[key] = remove_force_convert(stage)
return cfg
for pipeline in ["base_transforms", "train_augments", "valid_augments", "test_augments"]:
if "loaders" in config and isinstance(config["loaders"], dict) and pipeline in config["loaders"]:
if pipeline.endswith("_augments"):
stages = config["loaders"][pipeline]
for stage in stages:
if "append" in stage:
if stage["append"]:
logger.warning("overriding augmentation stage ordering")
del stage["append"]
if "operation" in stage and stage["operation"] == "Augmentor.Pipeline":
if "params" in stage:
stage["params"] = stage["params"]["operations"]
elif "parameters" in stage:
stage["params"] = stage["parameters"]["operations"]
del stage["parameters"]
config["loaders"][pipeline] = {"append": False, "transforms": remove_force_convert(stages)}
else:
config["loaders"][pipeline] = remove_force_convert(config["loaders"][pipeline])
cfg_ver = [0, 2, 0] # set ver for next update step
if cfg_ver[0] <= 0 and cfg_ver[1] <= 2 and cfg_ver[2] < 5:
# TODO: add scheduler 0-based step fix here? (unlikely to cause serious issues)
cfg_ver = [0, 2, 5] # set ver for next update step
if cfg_ver[0] <= 0 and cfg_ver[1] <= 3 and cfg_ver[2] < 6:
if "trainer" in config:
if "eval_metrics" in config["trainer"]:
assert "valid_metrics" not in config["trainer"]
config["trainer"]["valid_metrics"] = config["trainer"]["eval_metrics"]
del config["trainer"]["eval_metrics"]
for mset in ["train_metrics", "valid_metrics", "test_metrics", "metrics"]:
if mset in config["trainer"]:
metrics_config = config["trainer"][mset]
for mname, mcfg in metrics_config.items():
if "type" in mcfg and mcfg["type"].endswith("ExternalMetric"):
assert "params" in mcfg
assert "goal" in mcfg["params"]
mcfg["params"]["metric_goal"] = mcfg["params"]["goal"]
del mcfg["params"]["goal"]
if "metric_params" in mcfg["params"]:
if isinstance(mcfg["params"]["metric_params"], list):
assert not mcfg["params"]["metric_params"], "cannot fill in kw names"
mcfg["params"]["metric_params"] = {}
elif "type" in mcfg and mcfg["type"].endswith("ROCCurve"):
assert "params" in mcfg
if "log_params" in mcfg["params"]:
logger.warning("disabling logging via ROCCurve metric")
del mcfg["params"]["log_params"]
cfg_ver = [0, 3, 6] # set ver for next update step
# if cfg_ver[0] <= x and cfg_ver[1] <= y and cfg_ver[2] <= z:
# ... add more compatibility fixes here
return config
def download_file(url, root, filename, md5=None):
"""Downloads a file from a given URL to a local destination.
Args:
url: path to query for the file (query will be based on urllib).
root: destination folder where the file should be saved.
filename: destination name for the file.
md5: optional, for md5 integrity check.
Returns:
The path to the downloaded file.
"""
# inspired from torchvision.datasets.utils.download_url; no dep check
from six.moves import urllib
root = os.path.expanduser(root)
fpath = os.path.join(root, filename)
try:
os.makedirs(root)
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise
if not os.path.isfile(fpath):
logger.info("Downloading %s to %s ..." % (url, fpath))
urllib.request.urlretrieve(url, fpath, reporthook)
sys.stdout.write("\r")
sys.stdout.flush()
if md5 is not None:
import hashlib
md5o = hashlib.md5()
with open(fpath, 'rb') as f:
for chunk in iter(lambda: f.read(1024 * 1024), b''):
md5o.update(chunk)
md5c = md5o.hexdigest()
if md5c != md5:
raise AssertionError("md5 check failed for '%s'" % fpath)
return fpath
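# A usage sketch (URL and destination are hypothetical; pass md5 to verify integrity):
#
#   fpath = download_file("https://example.com/archive.tar.gz",
#                         root="./downloads", filename="archive.tar.gz")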
def reporthook(count, block_size, total_size):
"""Report hook used to display a download progression bar when using urllib requests."""
global start_time
if count == 0:
start_time = time.time()
return
duration = time.time() - start_time
progress_size = int(count * block_size)
speed = str(int(progress_size / (1024 * duration))) if duration > 0 else "?"
percent = min(int(count * block_size * 100 / total_size), 100)
sys.stdout.write("\r\t=> downloaded %d%% (%d MB) @ %s KB/s..." %
(percent, progress_size / (1024 * 1024), speed))
sys.stdout.flush()
def init_logger(log_level=logging.NOTSET, filename=None, force_stdout=False):
"""Initializes the framework logger with a specific filter level, and optional file output."""
logging.getLogger().setLevel(logging.NOTSET)
thelper.logger.propagate = 0
logger_format = logging.Formatter("[%(asctime)s - %(name)s] %(levelname)s : %(message)s")
if filename is not None:
logger_fh = logging.FileHandler(filename)
logger_fh.setLevel(logging.NOTSET)
logger_fh.setFormatter(logger_format)
thelper.logger.addHandler(logger_fh)
stream = sys.stdout if force_stdout else None
logger_ch = logging.StreamHandler(stream=stream)
logger_ch.setLevel(log_level)
logger_ch.setFormatter(logger_format)
thelper.logger.addHandler(logger_ch)
def resolve_import(fullname):
# type: (str) -> str
"""
Class name resolver.
Takes a string corresponding to a module and class fullname to be imported with :func:`thelper.utils.import_class`
and resolves any back compatibility issues related to renamed or moved classes.
Args:
fullname: the fully qualified class name to be resolved.
Returns:
The resolved class fullname.
"""
removed_cases = [
'thelper.optim.metrics.RawPredictions', # removed in 0.3.5
]
if fullname in removed_cases:
raise AssertionError(f"class {repr(fullname)} was deprecated and removed in a previous version")
refactor_cases = [
('thelper.modules', 'thelper.nn'),
('thelper.samplers', 'thelper.data.samplers'),
('thelper.optim.BinaryAccuracy', 'thelper.optim.metrics.Accuracy'),
('thelper.optim.CategoryAccuracy', 'thelper.optim.metrics.Accuracy'),
('thelper.optim.ClassifLogger', 'thelper.train.utils.ClassifLogger'),
('thelper.optim.ClassifReport', 'thelper.train.utils.ClassifReport'),
('thelper.optim.ConfusionMatrix', 'thelper.train.utils.ConfusionMatrix'),
('thelper.optim.metrics.BinaryAccuracy', 'thelper.optim.metrics.Accuracy'),
('thelper.optim.metrics.CategoryAccuracy', 'thelper.optim.metrics.Accuracy'),
('thelper.optim.metrics.ClassifLogger', 'thelper.train.utils.ClassifLogger'),
('thelper.optim.metrics.ClassifReport', 'thelper.train.utils.ClassifReport'),
('thelper.optim.metrics.ConfusionMatrix', 'thelper.train.utils.ConfusionMatrix'),
('thelper.transforms.ImageTransformWrapper', 'thelper.transforms.TransformWrapper'),
('thelper.transforms.wrappers.ImageTransformWrapper', 'thelper.transforms.wrappers.TransformWrapper'),
]
old_name = fullname
for old, new in refactor_cases:
fullname = fullname.replace(old, new)
if old_name != fullname:
logger.warning("class fullname '{!s}' was resolved to '{!s}'.".format(old_name, fullname))
return fullname
def import_class(fullname):
# type: (str) -> Type
"""General-purpose runtime class importer.
Supported syntax:
1. ``module.package.Class`` will import the fully qualified ``Class`` located
in ``package`` from the *installed* ``module``
2. ``/some/path/mod.pkg.Cls`` will import ``Cls`` as fully qualified ``mod.pkg.Cls`` from
``/some/path`` directory
Args:
fullname: the fully qualified class name to be imported.
Returns:
The imported class.
"""
assert isinstance(fullname, str)
fullname = pathlib.Path(fullname).as_posix()
if "/" in fullname:
mod_path, mod_cls_name = fullname.rsplit("/", 1)
pkg_name = mod_cls_name.rsplit(".", 1)[0]
pkg_file = os.path.join(mod_path, pkg_name.replace(".", "/")) + ".py"
mod_cls_name = resolve_import(mod_cls_name)
spec = importlib.util.spec_from_file_location(mod_cls_name, pkg_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
class_name = mod_cls_name.rsplit('.', 1)[-1]
else:
fullname = resolve_import(fullname)
module_name, class_name = fullname.rsplit('.', 1)
module = importlib.import_module(module_name)
return getattr(module, class_name)
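# A usage sketch (the path-based form assumes /some/path/mod/pkg.py exists on disk):
#
#   struct_cls = import_class("thelper.utils.Struct")   # installed module syntax
#   other_cls = import_class("/some/path/mod.pkg.Cls")  # path-based syntax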
def import_function(fullname, params=None):
# type: (str, Optional[thelper.typedefs.ConfigDict]) -> FunctionType
"""General-purpose runtime function importer, with support for param binding.
Args:
fullname: the fully qualified function name to be imported.
params: optional params dictionary to bind to the function call via functools.
Returns:
The imported function, with optionally bound parameters.
"""
func = import_class(fullname)
if params is not None:
if not isinstance(params, dict):
raise AssertionError("unexpected params dict type")
return functools.partial(func, **params)
return func
def check_func_signature(func,    # type: FunctionType
params # type: List[str]
): # type: (...) -> None
"""Checks whether the signature of a function matches the expected parameter list."""
if func is None or not callable(func):
raise AssertionError("invalid function object")
if params is not None:
if not isinstance(params, list) or not all([isinstance(p, str) for p in params]):
raise AssertionError("unexpected param name list format")
        func_sig = inspect.signature(func)
for p in params:
if p not in func_sig.parameters:
raise AssertionError("function missing parameter '%s'" % p)
def encode_data(data, approach="lz4", **kwargs):
"""Encodes a numpy array using a given coding approach.
Args:
data: the numpy array to encode.
        approach: the encoding; supports `none`, `lz4`, `jpg`, `jpeg`, `png`.
.. seealso::
| :func:`thelper.utils.decode_data`
"""
    supported_approaches = ["none", "lz4", "jpg", "jpeg", "png"]
if approach not in supported_approaches:
raise AssertionError(f"unexpected approach type (got '{approach}')")
if approach == "none":
return data
elif approach == "lz4":
return lz4.frame.compress(data, **kwargs)
elif approach == "jpg" or approach == "jpeg":
ret, buf = cv.imencode(".jpg", data, **kwargs)
elif approach == "png":
ret, buf = cv.imencode(".png", data, **kwargs)
else:
raise NotImplementedError
if not ret:
raise AssertionError("failed to encode data")
return buf
def decode_data(data, approach="lz4", **kwargs):
"""Decodes a binary array using a given coding approach.
Args:
data: the binary array to decode.
        approach: the encoding; supports `none`, `lz4`, `jpg`, `jpeg`, `png`.
.. seealso::
| :func:`thelper.utils.encode_data`
"""
    supported_approach_types = ["none", "lz4", "jpg", "jpeg", "png"]
if approach not in supported_approach_types:
raise AssertionError(f"unexpected approach type (got '{approach}')")
if approach == "none":
return data
elif approach == "lz4":
return lz4.frame.decompress(data, **kwargs)
elif approach in ["jpg", "jpeg", "png"]:
kwargs = copy.deepcopy(kwargs)
        if isinstance(kwargs.get("flags"), str):  # opencv requires an integer flag; allow symbolic names
kwargs["flags"] = eval(kwargs["flags"])
return cv.imdecode(data, **kwargs)
else:
raise NotImplementedError
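# A round-trip sketch for the lz4 path; the jpg/png paths expect opencv-compatible
# image arrays, and decoding those requires an integer 'flags' value (e.g. cv.IMREAD_COLOR):
#
#   raw = np.zeros((64, 64, 3), dtype=np.uint8).tobytes()
#   packed = encode_data(raw, approach="lz4")
#   assert decode_data(packed, approach="lz4") == raw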
def get_class_logger(skip=0):
"""Shorthand to get logger for current class frame."""
return logging.getLogger(get_caller_name(skip + 1).rsplit(".", 1)[0])
def get_func_logger(skip=0):
"""Shorthand to get logger for current function frame."""
return logging.getLogger(get_caller_name(skip + 1))
def get_caller_name(skip=2):
# source: https://gist.github.com/techtonik/2151727
"""Returns the name of a caller in the format module.class.method.
Args:
skip: specifies how many levels of stack to skip while getting the caller.
Returns:
An empty string is returned if skipped levels exceed stack height; otherwise,
returns the requested caller name.
"""
def stack_(frame):
frame_list = []
while frame:
frame_list.append(frame)
frame = frame.f_back
return frame_list
# noinspection PyProtectedMember
stack = stack_(sys._getframe(1))
start = 0 + skip
if len(stack) < start + 1:
return ""
parent_frame = stack[start]
name = []
module = inspect.getmodule(parent_frame)
# `modname` can be None when frame is executed directly in console
if module:
name.append(module.__name__)
# detect class name
if "self" in parent_frame.f_locals:
# I don't know any way to detect call from the object method
# XXX: there seems to be no way to detect static method call - it will
# be just a function call
name.append(parent_frame.f_locals["self"].__class__.__name__)
codename = parent_frame.f_code.co_name
if codename != "<module>": # top level usually
name.append(codename) # function or a method
del parent_frame
return ".".join(name)
def get_key(key, config, msg=None, delete=False):
"""Returns a value given a dictionary key, throwing if not available."""
if isinstance(key, list):
if len(key) <= 1:
if msg is not None:
raise AssertionError(msg)
else:
raise AssertionError("must provide at least two valid keys to test")
for k in key:
if k in config:
val = config[k]
if delete:
del config[k]
return val
if msg is not None:
raise AssertionError(msg)
else:
raise AssertionError("config dictionary missing a field named as one of '%s'" % str(key))
else:
if key not in config:
if msg is not None:
raise AssertionError(msg)
else:
raise AssertionError("config dictionary missing '%s' field" % key)
else:
val = config[key]
if delete:
del config[key]
return val
def get_key_def(key, config, default=None, msg=None, delete=False):
"""Returns a value given a dictionary key, or the default value if it cannot be found."""
if isinstance(key, list):
if len(key) <= 1:
if msg is not None:
raise AssertionError(msg)
else:
raise AssertionError("must provide at least two valid keys to test")
for k in key:
if k in config:
val = config[k]
if delete:
del config[k]
return val
return default
else:
if key not in config:
return default
else:
val = config[key]
if delete:
del config[key]
return val
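# A usage sketch with a hypothetical config dictionary:
#
#   config = {"backend": "Agg", "interactive": False}
#   backend = get_key("backend", config)  # raises if the key is missing
#   verbose = get_key_def(["verbose", "debug"], config, False)  # falls back to default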
def get_log_stamp():
"""Returns a print-friendly and filename-friendly identification string containing platform and time."""
return str(platform.node()) + "-" + time.strftime("%Y%m%d-%H%M%S")
def get_git_stamp():
"""Returns a print-friendly SHA signature for the framework's underlying git repository (if found)."""
try:
import git
try:
repo = git.Repo(path=os.path.abspath(__file__), search_parent_directories=True)
sha = repo.head.object.hexsha
return str(sha)
except (AttributeError, git.InvalidGitRepositoryError):
return "unknown"
except (ImportError, AttributeError):
return "unknown"
def get_env_list():
"""Returns a list of all packages installed in the current environment.
If the required packages cannot be imported, the returned list will be empty. Note that some
packages may not be properly detected by this approach, and it is pretty hacky, so use it with
a grain of salt (i.e. logging is fine).
"""
try:
import pip
# noinspection PyUnresolvedReferences
pkgs = pip.get_installed_distributions()
return sorted(["%s %s" % (pkg.key, pkg.version) for pkg in pkgs])
except (ImportError, AttributeError):
try:
import pkg_resources as pkgr
return sorted([str(pkg) for pkg in pkgr.working_set])
except (ImportError, AttributeError):
return []
def str2size(input_str):
"""Returns a (WIDTH, HEIGHT) integer size tuple from a string formatted as 'WxH'."""
if not isinstance(input_str, str):
raise AssertionError("unexpected input type")
display_size_str = input_str.split('x')
if len(display_size_str) != 2:
raise AssertionError("bad size string formatting")
return tuple([max(int(substr), 1) for substr in display_size_str])
def str2bool(s):
"""Converts a string to a boolean.
If the lower case version of the provided string matches any of 'true', '1', or
'yes', then the function returns ``True``.
"""
if isinstance(s, bool):
return s
if isinstance(s, (int, float)):
return s != 0
if isinstance(s, str):
positive_flags = ["true", "1", "yes"]
return s.lower() in positive_flags
raise AssertionError("unrecognized input type")
def clipstr(s, size, fill=" "):
    """Clips a string to a specific length, left-padding it with a fill character if it is too short."""
if len(s) > size:
s = s[:size]
if len(s) < size:
s = fill * (size - len(s)) + s
return s
def lreplace(string, old_prefix, new_prefix):
    """Replaces each leading occurrence of `old_prefix` in the given string by `new_prefix`."""
return re.sub(r'^(?:%s)+' % re.escape(old_prefix), lambda m: new_prefix * (m.end() // len(old_prefix)), string)
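# A prefix replacement sketch: each repeated leading occurrence is substituted.
#
#   assert lreplace("foofoobar", "foo", "baz") == "bazbazbar"
#   assert lreplace("barfoo", "foo", "baz") == "barfoo"  # only leading prefixes match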
def query_yes_no(question, default=None, bypass=None):
"""Asks the user a yes/no question and returns the answer.
Args:
question: the string that is presented to the user.
default: the presumed answer if the user just hits ``<Enter>``. It must be 'yes',
'no', or ``None`` (meaning an answer is required).
bypass: the option to select if the ``bypass_queries`` global variable is set to
``True``. Can be ``None``, in which case the function will throw an exception.
Returns:
``True`` for 'yes', or ``False`` for 'no' (or their respective variations).
"""
valid = {"yes": True, "ye": True, "y": True, "no": False, "n": False}
if bypass is not None and (not isinstance(bypass, str) or bypass not in valid):
raise AssertionError("unexpected bypass value")
if bypass_queries:
if bypass is None:
raise AssertionError("cannot bypass interactive query, no default value provided")
return valid[bypass]
if (isinstance(default, bool) and default) or \
(isinstance(default, str) and default.lower() in ["yes", "ye", "y"]):
prompt = " [Y/n] "
elif (isinstance(default, bool) and not default) or \
(isinstance(default, str) and default.lower() in ["no", "n"]):
prompt = " [y/N] "
else:
prompt = " [y/n] "
sys.stdout.flush()
sys.stderr.flush()
time.sleep(0.25) # to make sure all debug/info prints are done, and we see the question
while True:
sys.stdout.write(question + prompt + "\n>> ")
choice = input().lower()
if default is not None and choice == "":
if isinstance(default, str):
return valid[default]
else:
return default
elif choice in valid:
return valid[choice]
else:
sys.stdout.write("Please respond with 'yes/y' or 'no/n'.\n")
def query_string(question, choices=None, default=None, allow_empty=False, bypass=None):
"""Asks the user a question and returns the answer (a generic string).
Args:
question: the string that is presented to the user.
choices: a list of predefined choices that the user can pick from. If
``None``, then whatever the user types will be accepted.
default: the presumed answer if the user just hits ``<Enter>``. If ``None``,
then an answer is required to continue.
allow_empty: defines whether an empty answer should be accepted.
bypass: the returned value if the ``bypass_queries`` global variable is set to
``True``. Can be ``None``, in which case the function will throw an exception.
Returns:
The string entered by the user.
"""
if bypass_queries:
if bypass is None:
raise AssertionError("cannot bypass interactive query, no default value provided")
return bypass
sys.stdout.flush()
sys.stderr.flush()
time.sleep(0.25) # to make sure all debug/info prints are done, and we see the question
while True:
msg = question
if choices is not None:
msg += "\n\t(choices=%s)" % str(choices)
if default is not None:
msg += "\n\t(default=%s)" % default
sys.stdout.write(msg + "\n>> ")
answer = input()
if answer == "":
if default is not None:
return default
elif allow_empty:
return answer
elif choices is not None:
if answer in choices:
return answer
else:
return answer
sys.stdout.write("Please respond with a valid string.\n")
def get_save_dir(out_root, dir_name, config=None, resume=False, backup_ext=".json"):
"""Returns a directory path in which the app can save its data.
If a folder with name ``dir_name`` already exists in the directory ``out_root``, then the user will be
asked to pick a new name. If the user refuses, ``sys.exit(1)`` is called. If config is not ``None``, it
will be saved to the output directory as a json file. Finally, a ``logs`` directory will also be created
in the output directory for writing logger files.
Args:
out_root: path to the directory root where the save directory should be created.
dir_name: name of the save directory to create. If it already exists, a new one will be requested.
config: dictionary of app configuration parameters. Used to overwrite i/o queries, and will be
written to the save directory in json format to test writing. Default is ``None``.
resume: specifies whether this session is new, or resumed from an older one (in the latter
case, overwriting is allowed, and the user will never have to choose a new folder)
backup_ext: extension to use when creating configuration file backups.
Returns:
The path to the created save directory for this session.
"""
func_logger = get_func_logger()
save_dir = out_root
if save_dir is None:
time.sleep(0.25) # to make sure all debug/info prints are done, and we see the question
save_dir = query_string("Please provide the path to where session directories should be created/saved:")
if not os.path.exists(save_dir):
os.makedirs(save_dir)
save_dir = os.path.join(save_dir, dir_name)
if not resume:
overwrite = str2bool(config["overwrite"]) if config is not None and "overwrite" in config else False
time.sleep(0.25) # to make sure all debug/info prints are done, and we see the question
while os.path.exists(save_dir) and not overwrite:
abs_save_dir = os.path.abspath(save_dir).replace("\\", "/")
overwrite = query_yes_no("Training session at '%s' already exists; overwrite?" % abs_save_dir, bypass="y")
if not overwrite:
save_dir = query_string("Please provide a new save directory path:")
if not os.path.exists(save_dir):
os.mkdir(save_dir)
if config is not None:
save_config(config, os.path.join(save_dir, "config.latest" + backup_ext))
else:
if not os.path.exists(save_dir):
os.mkdir(save_dir)
if config is not None:
backup_path = os.path.join(save_dir, "config.latest" + backup_ext)
if os.path.exists(backup_path):
with open(backup_path, "r") as fd:
config_backup = json.load(fd)
if config_backup != config:
query_msg = f"Config backup in '{backup_path}' differs from config loaded through checkpoint; overwrite?"
answer = query_yes_no(query_msg, bypass="y")
if answer:
func_logger.warning("config mismatch with previous run; "
"will overwrite latest backup in save directory")
else:
func_logger.error("config mismatch with previous run; user aborted")
sys.exit(1)
save_config(config, backup_path)
logs_dir = os.path.join(save_dir, "logs")
if not os.path.exists(logs_dir):
os.mkdir(logs_dir)
save_config(config, os.path.join(logs_dir, "config." + thelper.utils.get_log_stamp() + backup_ext))
return save_dir
def save_config(config, path, force_convert=True):
"""Saves the given session/object configuration dictionary to the provided path.
The type of file that is created is based on the extension specified in the path. If the file
cannot hold some of the objects within the configuration, they will be converted to strings before
serialization, unless `force_convert` is set to `False` (in which case the function will raise
an exception).
Args:
config: the session/object configuration dictionary to save.
path: the path specifying where to create the output file. The extension used will determine
what type of backup to create (e.g. Pickle = .pkl, JSON = .json).
force_convert: specifies whether non-serializable types should be converted if necessary.
"""
if path.endswith(".json"):
serializer = (lambda x: str(x)) if force_convert else None
with open(path, "w") as fd:
json.dump(config, fd, indent=4, sort_keys=False, default=serializer)
elif path.endswith(".pkl"):
with open(path, "w") as fd:
pickle.dump(config, fd)
else:
raise AssertionError("unknown output file type")
def save_env_list(path):
"""Saves a list of all packages installed in the current environment to a log file.
Args:
path: the path where the log file should be created.
"""
with open(path, "w") as fd:
pkgs_list = thelper.utils.get_env_list()
if pkgs_list:
for pkg in pkgs_list:
fd.write("%s\n" % pkg)
else:
fd.write("<n/a>\n")
def safe_crop(image, tl, br, bordertype=cv.BORDER_CONSTANT, borderval=0, force_copy=False):
"""Safely crops a region from within an image, padding borders if needed.
Args:
image: the image to crop (provided as a numpy array).
tl: a tuple or list specifying the (x,y) coordinates of the top-left crop corner.
br: a tuple or list specifying the (x,y) coordinates of the bottom-right crop corner.
bordertype: border copy type to use when the image is too small for the required crop size.
See ``cv2.copyMakeBorder`` for more information.
borderval: border value to use when the image is too small for the required crop size. See
``cv2.copyMakeBorder`` for more information.
force_copy: defines whether to force a copy of the target image region even when it can be
avoided.
Returns:
The cropped image.
"""
if not isinstance(image, np.ndarray):
raise AssertionError("expected input image to be numpy array")
if isinstance(tl, tuple):
tl = list(tl)
if isinstance(br, tuple):
br = list(br)
if not isinstance(tl, list) or not isinstance(br, list):
raise AssertionError("expected tl/br coords to be provided as tuple or list")
if tl[0] < 0 or tl[1] < 0 or br[0] > image.shape[1] or br[1] > image.shape[0]:
image = cv.copyMakeBorder(image, max(-tl[1], 0), max(br[1] - image.shape[0], 0),
max(-tl[0], 0), max(br[0] - image.shape[1], 0),
borderType=bordertype, value=borderval)
if tl[0] < 0:
br[0] -= tl[0]
tl[0] = 0
if tl[1] < 0:
br[1] -= tl[1]
tl[1] = 0
return image[tl[1]:br[1], tl[0]:br[0], ...]
if force_copy:
return np.copy(image[tl[1]:br[1], tl[0]:br[0], ...])
return image[tl[1]:br[1], tl[0]:br[0], ...]
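# A usage sketch: crop a 100x100 patch whose top-left corner falls outside the image;
# the out-of-bounds region is padded via cv.copyMakeBorder with the given border value.
#
#   img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
#   patch = safe_crop(img, tl=(-20, -20), br=(80, 80))
#   assert patch.shape[:2] == (100, 100)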
def get_bgr_from_hsl(hue, sat, light):
    """Converts a single HSL triplet (0-360 hue, 0-1 sat & lightness) into an 8-bit BGR triplet."""
# this function is not intended for fast conversions; use OpenCV's cvtColor for large-scale stuff
if hue < 0 or hue > 360:
raise AssertionError("invalid hue")
if sat < 0 or sat > 1:
raise AssertionError("invalid saturation")
if light < 0 or light > 1:
raise AssertionError("invalid lightness")
if sat == 0:
return (int(np.clip(round(light * 255), 0, 255)),) * 3
if light == 0:
return 0, 0, 0
if light == 1:
return 255, 255, 255
def h2rgb(_p, _q, _t):
if _t < 0:
_t += 1
if _t > 1:
_t -= 1
if _t < 1 / 6:
return _p + (_q - _p) * 6 * _t
if _t < 1 / 2:
return _q
if _t < 2 / 3:
return _p + (_q - _p) * (2 / 3 - _t) * 6
return _p
q = light * (1 + sat) if (light < 0.5) else light + sat - light*sat
p = 2 * light - q
h = hue / 360
return (int(np.clip(round(h2rgb(p, q, h - 1 / 3) * 255), 0, 255)),
int(np.clip(round(h2rgb(p, q, h) * 255), 0, 255)),
int(np.clip(round(h2rgb(p, q, h + 1 / 3) * 255), 0, 255)))
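# A conversion sketch: pure red at full saturation and mid lightness (note the BGR ordering):
#
#   assert get_bgr_from_hsl(0, 1.0, 0.5) == (0, 0, 255)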
def get_displayable_image(image,      # type: thelper.typedefs.ArrayType
grayscale=False, # type: Optional[bool]
): # type: (...) -> thelper.typedefs.ArrayType
"""Returns a 'displayable' image that has been normalized and padded to three channels."""
if image.ndim != 3:
raise AssertionError("indexing should return a pre-squeezed array")
if image.shape[2] == 2:
image = np.dstack((image, image[:, :, 0]))
elif image.shape[2] > 3:
image = image[..., :3]
if grayscale and image.shape[2] != 1:
image = cv.cvtColor(image, cv.COLOR_BGR2GRAY)
elif not grayscale and image.shape[2] == 1:
image = cv.cvtColor(image, cv.COLOR_GRAY2BGR)
image_normalized = np.empty_like(image, dtype=np.uint8).copy() # copy needed here due to ocv 3.3 bug
cv.normalize(image, image_normalized, 0, 255, cv.NORM_MINMAX, dtype=cv.CV_8U)
return image_normalized
def get_displayable_heatmap(array,      # type: thelper.typedefs.ArrayType
convert_rgb=True, # type: Optional[bool]
): # type: (...) -> thelper.typedefs.ArrayType
"""Returns a 'displayable' array that has been min-maxed and mapped to color triplets."""
if array.ndim != 2:
array = np.squeeze(array)
if array.ndim != 2:
raise AssertionError("indexing should return a pre-squeezed array")
array_normalized = np.empty_like(array, dtype=np.uint8).copy() # copy needed here due to ocv 3.3 bug
cv.normalize(array, array_normalized, 0, 255, cv.NORM_MINMAX, dtype=cv.CV_8U)
heatmap = cv.applyColorMap(array_normalized, cv.COLORMAP_JET)
if convert_rgb:
heatmap = cv.cvtColor(heatmap, cv.COLOR_BGR2RGB)
return heatmap
def is_scalar(val):
"""Returns whether the input value is a scalar according to numpy and PyTorch."""
if np.isscalar(val):
return True
if isinstance(val, torch.Tensor) and (val.dim() == 0 or val.numel() == 1):
return True
return False
def to_numpy(array):
"""Converts a list or PyTorch tensor to numpy. Does nothing if already a numpy array."""
if isinstance(array, list):
return np.asarray(array)
elif isinstance(array, torch.Tensor):
return array.cpu().numpy()
elif isinstance(array, np.ndarray):
return array
else:
raise AssertionError(f"unexpected input type ({type(array)})")
def draw_histogram(data,            # type: thelper.typedefs.ArrayType
bins=50, # type: Optional[int]
xlabel="", # type: Optional[thelper.typedefs.LabelType]
ylabel="Proportion", # type: Optional[thelper.typedefs.LabelType]
show=False, # type: Optional[bool]
block=False, # type: Optional[bool]
): # type: (...) -> thelper.typedefs.DrawingType
"""Draws and returns a histogram figure using pyplot."""
fig, ax = plt.subplots()
ax.hist(data, density=True, bins=bins)
if len(ylabel) > 0:
ax.set_ylabel(ylabel)
if len(xlabel) > 0:
ax.set_xlabel(xlabel)
    ax.set_xlim(left=0)
if show:
fig.show()
if block:
plt.show(block=block)
return fig
plt.pause(0.5)
return fig, ax
def draw_popbars(labels,            # type: thelper.typedefs.LabelList
                 counts,            # type: List[int]
xlabel="", # type: Optional[thelper.typedefs.LabelType]
ylabel="Pop. Count", # type: Optional[thelper.typedefs.LabelType]
show=False, # type: Optional[bool]
block=False, # type: Optional[bool]
): # type: (...) -> thelper.typedefs.DrawingType
"""Draws and returns a bar histogram figure using pyplot."""
fig, ax = plt.subplots()
xrange = range(len(labels))
ax.bar(xrange, counts, align="center")
if len(ylabel) > 0:
ax.set_ylabel(ylabel)
if len(xlabel) > 0:
ax.set_xlabel(xlabel)
ax.set_xticks(xrange)
ax.set_xticklabels(labels)
ax.tick_params(axis="x", labelsize="8", labelrotation=45)
if show:
fig.show()
if block:
plt.show(block=block)
return fig
plt.pause(0.5)
return fig, ax
def draw_pascalvoc_curve(metrics, size_inch=(5, 5), dpi=320, show=False, block=False):
"""Draws and returns a precision-recall curve according to pascalvoc metrics."""
# note: the 'metrics' must correspond to a single class output produced by pascalvoc evaluator
assert isinstance(metrics, dict), "unexpected metrics format"
class_name = metrics["class_name"]
assert isinstance(class_name, str), "unexpected class name type"
iou_threshold = metrics["iou_threshold"]
assert 0 < iou_threshold <= 1, "invalid intersection over union value (should be in ]0,1])"
method = metrics["eval_method"]
assert method in ["all-points", "11-points"], "invalid method (should be 'all-points' or '11-points')"
fig = plt.figure(num="pr", figsize=size_inch, dpi=dpi, facecolor="w", edgecolor="k")
fig.clf()
ax = fig.add_subplot(1, 1, 1)
ax.plot(metrics["recall"], metrics["precision"],
label=f"{class_name} (AP={metrics['AP'] * 100:.2f}%)")
ax.set_xlabel("recall")
ax.set_ylabel("precision")
ax.set_title(f"PascalVOC PR Curve @ {iou_threshold} IoU")
ax.legend(loc="upper right")
ax.grid()
fig.set_tight_layout(True)
if show:
fig.show()
if block:
plt.show(block=block)
return fig
plt.pause(0.5)
return fig, ax
def draw_images(images,             # type: thelper.typedefs.OneOrManyArrayType
captions=None, # type: Optional[List[str]]
redraw=None, # type: Optional[thelper.typedefs.DrawingType]
show=True, # type: Optional[bool]
block=False, # type: Optional[bool]
use_cv2=True, # type: Optional[bool]
cv2_flip_bgr=True, # type: Optional[bool]
img_shape=None, # type: Optional[thelper.typedefs.ArrayShapeType]
max_img_size=None, # type: Optional[thelper.typedefs.ArrayShapeType]
grid_size_x=None, # type: Optional[int]
grid_size_y=None, # type: Optional[int]
caption_opts=None,
window_name=None, # type: Optional[str]
): # type: (...) -> thelper.typedefs.DrawingType
"""Draws a set of images with optional captions."""
nb_imgs = len(images) if isinstance(images, list) else images.shape[0]
if nb_imgs < 1:
return None
assert captions is None or len(captions) == nb_imgs, "captions count mismatch with image count"
# for display on typical monitors... (height, width)
max_img_size = (800, 1600) if max_img_size is None else max_img_size
grid_size_x = int(math.ceil(math.sqrt(nb_imgs))) if grid_size_x is None else grid_size_x
grid_size_y = int(math.ceil(nb_imgs / grid_size_x)) if grid_size_y is None else grid_size_y
assert grid_size_x * grid_size_y >= nb_imgs, f"bad gridding for subplots (need at least {nb_imgs} tiles)"
if use_cv2:
if caption_opts is None:
caption_opts = {
"org": (10, 40),
"fontFace": cv.FONT_HERSHEY_SIMPLEX,
"fontScale": 0.40,
"color": (255, 255, 255),
"thickness": 1,
"lineType": cv.LINE_AA
}
if window_name is None:
window_name = "images"
img_grid_shape = None
img_grid = None if redraw is None else redraw[1]
for img_idx in range(nb_imgs):
image = images[img_idx] if isinstance(images, list) else images[img_idx, ...]
if img_shape is None:
img_shape = image.shape
if img_grid_shape is None:
img_grid_shape = (img_shape[0] * grid_size_y, img_shape[1] * grid_size_x, img_shape[2])
if img_grid is None or img_grid.shape != img_grid_shape:
img_grid = np.zeros(img_grid_shape, dtype=np.uint8)
if image.shape[2] != img_shape[2]:
raise AssertionError(f"unexpected image depth ({image.shape[2]} vs {img_shape[2]})")
if image.shape != img_shape:
image = cv.resize(image, (img_shape[1], img_shape[0]), interpolation=cv.INTER_NEAREST)
if captions is not None and str(captions[img_idx]):
image = cv.putText(image.copy(), str(captions[img_idx]), **caption_opts)
offsets = (img_idx // grid_size_x) * img_shape[0], (img_idx % grid_size_x) * img_shape[1]
np.copyto(img_grid[offsets[0]:(offsets[0] + img_shape[0]), offsets[1]:(offsets[1] + img_shape[1]), :], image)
win_name = str(window_name) if redraw is None else redraw[0]
if img_grid is not None:
display = img_grid[..., ::-1] if cv2_flip_bgr else img_grid
if display.shape[0] > max_img_size[0] or display.shape[1] > max_img_size[1]:
if display.shape[0] / max_img_size[0] > display.shape[1] / max_img_size[1]:
dsize = (max_img_size[0], int(round(display.shape[1] / (display.shape[0] / max_img_size[0]))))
else:
dsize = (int(round(display.shape[0] / (display.shape[1] / max_img_size[1]))), max_img_size[1])
display = cv.resize(display, (dsize[1], dsize[0]))
if show:
cv.imshow(win_name, display)
cv.waitKey(0 if block else 1)
return win_name, img_grid
else:
fig, axes = redraw if redraw is not None else plt.subplots(grid_size_y, grid_size_x)
if nb_imgs == 1:
axes = np.array(axes)
for ax_idx, ax in enumerate(axes.reshape(-1)):
if ax_idx < nb_imgs:
image = images[ax_idx] if isinstance(images, list) else images[ax_idx, ...]
if image.shape != img_shape:
image = cv.resize(image, (img_shape[1], img_shape[0]), interpolation=cv.INTER_NEAREST)
ax.imshow(image, interpolation='nearest')
if captions is not None and str(captions[ax_idx]):
ax.set_xlabel(str(captions[ax_idx]))
ax.set_xticks([])
ax.set_yticks([])
fig.set_tight_layout(True)
if show:
fig.show()
if block:
plt.show(block=block)
return None
plt.pause(0.5)
return fig, axes
def draw_predicts(images, preds=None, targets=None, swap_channels=False, redraw=None, block=False, **kwargs):
"""Draws and returns a set of generic prediction results."""
image_list = [get_displayable_image(images[batch_idx, ...]) for batch_idx in range(images.shape[0])]
image_gray_list = [cv.cvtColor(cv.cvtColor(image, cv.COLOR_BGR2GRAY), cv.COLOR_GRAY2BGR) for image in image_list]
nb_imgs = len(image_list)
caption_list = [""] * nb_imgs
grid_size_x, grid_size_y = nb_imgs, 1 # all images on one row, by default (add gt and preds as extra rows)
if targets is not None:
if not isinstance(targets, list) and not (isinstance(targets, torch.Tensor) and targets.shape[0] == nb_imgs):
raise AssertionError("expected targets to be in list or tensor format (Bx...)")
if isinstance(targets, list):
if all([isinstance(t, list) for t in targets]):
targets = list(itertools.chain.from_iterable(targets)) # merge all augmented lists together
targets = torch.cat(targets, 0) # merge all masks into a single tensor
if targets.shape[0] != nb_imgs:
raise AssertionError("images/targets count mismatch")
targets = targets.numpy()
if swap_channels:
if not targets.ndim == 4:
raise AssertionError("unexpected swap for targets tensor that is not 4-dim")
targets = np.transpose(targets, (0, 2, 3, 1)) # BxCxHxW to BxHxWxC
if ((targets.ndim == 4 and targets.shape[1] == 1) or targets.ndim == 3) and targets.shape[-2:] == images.shape[1:3]:
target_list = [get_displayable_heatmap(targets[batch_idx, ...]) for batch_idx in range(nb_imgs)]
target_list = [cv.addWeighted(image_gray_list[idx], 0.3, target_list[idx], 0.7, 0) for idx in range(nb_imgs)]
image_list += target_list
caption_list += [""] * nb_imgs
grid_size_y += 1
elif targets.shape == images.shape:
image_list += [get_displayable_image(targets[batch_idx, ...]) for batch_idx in range(nb_imgs)]
caption_list += [""] * nb_imgs
grid_size_y += 1
else:
for idx in range(nb_imgs):
caption_list[idx] = f"GT={str(targets[idx])}"
if preds is not None:
if not isinstance(preds, list) and not (isinstance(preds, torch.Tensor) and preds.shape[0] == nb_imgs):
raise AssertionError("expected preds to be in list or tensor shape (Bx...)")
if isinstance(preds, list):
if all([isinstance(p, list) for p in preds]):
preds = list(itertools.chain.from_iterable(preds)) # merge all augmented lists together
preds = torch.cat(preds, 0) # merge all preds into a single tensor
if preds.shape[0] != nb_imgs:
raise AssertionError("images/preds count mismatch")
preds = preds.numpy()
if swap_channels:
if not preds.ndim == 4:
raise AssertionError("unexpected swap for targets tensor that is not 4-dim")
preds = np.transpose(preds, (0, 2, 3, 1)) # BxCxHxW to BxHxWxC
if targets is not None and preds.shape != targets.shape:
raise AssertionError("preds/targets shape mismatch")
if ((preds.ndim == 4 and preds.shape[1] == 1) or preds.ndim == 3) and preds.shape[-2:] == images.shape[1:3]:
pred_list = [get_displayable_heatmap(preds[batch_idx, ...]) for batch_idx in range(nb_imgs)]
pred_list = [cv.addWeighted(image_gray_list[idx], 0.3, pred_list[idx], 0.7, 0) for idx in range(nb_imgs)]
image_list += pred_list
caption_list += [""] * nb_imgs
grid_size_y += 1
elif preds.shape == images.shape:
image_list += [get_displayable_image(preds[batch_idx, ...]) for batch_idx in range(nb_imgs)]
caption_list += [""] * nb_imgs
grid_size_y += 1
else:
for idx in range(nb_imgs):
if len(caption_list[idx]) != 0:
caption_list[idx] += ", "
caption_list[idx] = f"Pred={str(preds[idx])}"
return draw_images(image_list, captions=caption_list, redraw=redraw, window_name="predictions",
block=block, grid_size_x=grid_size_x, grid_size_y=grid_size_y, **kwargs)
def draw_segments(images, preds=None, masks=None, color_map=None, redraw=None, block=False, **kwargs):
"""Draws and returns a set of segmentation results."""
image_list = [get_displayable_image(images[batch_idx, ...]) for batch_idx in range(images.shape[0])]
image_gray_list = [cv.cvtColor(cv.cvtColor(image, cv.COLOR_BGR2GRAY), cv.COLOR_GRAY2BGR) for image in image_list]
nb_imgs = len(image_list)
grid_size_x, grid_size_y = nb_imgs, 1 # all images on one row, by default (add gt and preds as extra rows)
if color_map is not None and isinstance(color_map, dict):
assert len(color_map) <= 256, "too many indices for uint8 map"
use_alpha = all([isinstance(val, np.ndarray) and val.dtype in (np.float32, np.float64)
for val in color_map.values()])
color_map_new = np.zeros((256, 1, 3), dtype=np.float32 if use_alpha else np.uint8)
for idx, val in color_map.items():
color_map_new[idx, ...] = val
color_map = color_map_new
if masks is not None:
if not isinstance(masks, list) and not (isinstance(masks, torch.Tensor) and masks.dim() == 3):
raise AssertionError("expected segmentation masks to be in list or 3-d tensor format (BxHxW)")
if isinstance(masks, list):
if all([isinstance(m, list) for m in masks]):
masks = list(itertools.chain.from_iterable(masks)) # merge all augmented lists together
masks = torch.cat(masks, 0) # merge all masks into a single tensor
if masks.shape[0] != nb_imgs:
raise AssertionError("images/masks count mismatch")
if images.shape[0:3] != masks.shape:
raise AssertionError("images/masks shape mismatch")
masks = masks.numpy()
if color_map is not None:
masks = [apply_color_map(masks[idx], color_map) for idx in range(masks.shape[0])]
image_list += [cv.addWeighted(image_gray_list[idx], 0.3, masks[idx], 0.7, 0)
if masks[idx].dtype == np.uint8 else (image_list[idx] * masks[idx]).astype(np.uint8)
for idx in range(nb_imgs)]
grid_size_y += 1
if preds is not None:
if not isinstance(preds, list) and not (isinstance(preds, torch.Tensor) and preds.dim() == 4):
raise AssertionError("expected segmentation preds to be in list or 3-d tensor format (BxCxHxW)")
if isinstance(preds, list):
if all([isinstance(p, list) for p in preds]):
preds = list(itertools.chain.from_iterable(preds)) # merge all augmented lists together
preds = torch.cat(preds, 0) # merge all preds into a single tensor
with torch.no_grad():
preds = torch.squeeze(preds.topk(k=1, dim=1)[1], dim=1) # keep top prediction index only
if preds.shape[0] != nb_imgs:
raise AssertionError("images/preds count mismatch")
if images.shape[0:3] != preds.shape:
raise AssertionError("images/preds shape mismatch")
preds = preds.numpy()
if color_map is not None:
preds = [apply_color_map(preds[idx], color_map) for idx in range(preds.shape[0])]
image_list += [cv.addWeighted(image_gray_list[idx], 0.3, preds[idx], 0.7, 0)
if preds[idx].dtype == np.uint8 else (image_list[idx] * preds[idx]).astype(np.uint8)
for idx in range(nb_imgs)]
grid_size_y += 1
return draw_images(image_list, redraw=redraw, window_name="segments", block=block,
grid_size_x=grid_size_x, grid_size_y=grid_size_y, **kwargs)
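# a minimal usage sketch for draw_segments (hypothetical shapes; assumes a
# BxHxWxC uint8 batch, BxHxW integer masks, and BxCxHxW prediction logits):
#   images = np.random.randint(0, 255, size=(2, 32, 32, 3), dtype=np.uint8)
#   masks = torch.randint(0, 3, (2, 32, 32))
#   logits = torch.rand(2, 3, 32, 32)
#   color_map = {idx: get_label_color_mapping(idx) for idx in range(3)}
#   draw_segments(images, preds=logits, masks=masks, color_map=color_map, block=True)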
[docs]def draw_classifs(images, preds=None, labels=None, class_names_map=None, redraw=None, block=False, **kwargs):
"""Draws and returns a set of classification results."""
image_list = [get_displayable_image(images[batch_idx, ...]) for batch_idx in range(images.shape[0])]
caption_list = [""] * len(image_list)
if labels is not None: # convert labels to flat list, if available
if not isinstance(labels, list) and not (isinstance(labels, torch.Tensor) and labels.dim() == 1):
raise AssertionError("expected classification labels to be in list or 1-d tensor format")
if isinstance(labels, list):
if all([isinstance(l, list) for l in labels]):
labels = list(itertools.chain.from_iterable(labels)) # merge all augmented lists together
if all([isinstance(t, torch.Tensor) for t in labels]):
labels = torch.cat(labels, 0)
if isinstance(labels, torch.Tensor):
labels = labels.tolist()
if images.shape[0] != len(labels):
raise AssertionError("images/labels count mismatch")
if class_names_map is not None:
labels = [class_names_map[lbl] if lbl in class_names_map else lbl for lbl in labels]
for idx in range(len(image_list)):
caption_list[idx] = f"GT={labels[idx]}"
if preds is not None: # convert predictions to flat list, if available
if not isinstance(preds, list) and not (isinstance(preds, torch.Tensor) and preds.dim() == 2):
raise AssertionError("expected classification predictions to be in list or 2-d tensor format (BxC)")
if isinstance(preds, list):
if all([isinstance(p, list) for p in preds]):
preds = list(itertools.chain.from_iterable(preds)) # merge all augmented lists together
if all([isinstance(t, torch.Tensor) for t in preds]):
preds = torch.cat(preds, 0)
with torch.no_grad():
preds = torch.squeeze(preds.topk(1, dim=1)[1], dim=1)
if images.shape[0] != preds.shape[0]:
raise AssertionError("images/predictions count mismatch")
preds = preds.tolist()
if class_names_map is not None:
preds = [class_names_map[lbl] if lbl in class_names_map else lbl for lbl in preds]
for idx in range(len(image_list)):
if len(caption_list[idx]) != 0:
caption_list[idx] += ", "
caption_list[idx] += f"Pred={preds[idx]}"
return draw_images(image_list, captions=caption_list, redraw=redraw, window_name="classifs", block=block, **kwargs)
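# a minimal usage sketch for draw_classifs (hypothetical data; assumes a
# BxHxWxC uint8 batch with BxC class scores and B ground-truth label indices):
#   images = np.random.randint(0, 255, size=(4, 64, 64, 3), dtype=np.uint8)
#   logits = torch.rand(4, 10)
#   labels = torch.randint(0, 10, (4,))
#   names = {idx: f"class{idx}" for idx in range(10)}
#   draw_classifs(images, preds=logits, labels=labels, class_names_map=names, block=True)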
[docs]def draw(task, input, pred=None, target=None, block=False, ch_transpose=True, flip_bgr=False, redraw=None, **kwargs):
"""Draws and returns a figure of a model input/predictions/targets using pyplot or OpenCV."""
# note: this function actually dispatches the drawing procedure using the task interface
import thelper.tasks
if not isinstance(task, thelper.tasks.Task):
raise AssertionError("invalid task object")
if isinstance(input, list) and all([isinstance(t, torch.Tensor) for t in input]):
        # if we have a list, it must be due to an augmentation stage
if not all([image.shape == input[0].shape for image in input]):
raise AssertionError("image shape mismatch throughout list")
input = torch.cat(input, 0) # merge all images into a single tensor
if not isinstance(input, torch.Tensor) or input.dim() != 4:
raise AssertionError("expected input images to be in 4-d tensor format (BxCxHxW or BxHxWxC)")
input = input.numpy().copy()
if ch_transpose:
input = np.transpose(input, (0, 2, 3, 1)) # BxCxHxW to BxHxWxC
if flip_bgr:
input = input[..., ::-1] # BGR to RGB
if pred is not None and isinstance(pred, torch.Tensor):
pred = pred.cpu().detach() # avoid latency for preprocessing on gpu
if target is not None and isinstance(target, torch.Tensor):
target = target.cpu().detach() # avoid latency for preprocessing on gpu
if isinstance(task, thelper.tasks.Classification):
class_names_map = {idx: name for name, idx in task.class_indices.items()}
return draw_classifs(images=input, preds=pred, labels=target,
class_names_map=class_names_map, redraw=redraw, block=block, **kwargs)
elif isinstance(task, thelper.tasks.Segmentation):
color_map = task.color_map if task.color_map else {idx: get_label_color_mapping(idx) for idx in task.class_indices.values()}
if task.dontcare is not None and task.dontcare not in color_map:
color_map[task.dontcare] = np.asarray([0, 0, 0])
return draw_segments(images=input, preds=pred, masks=target, color_map=color_map, redraw=redraw, block=block, **kwargs)
elif isinstance(task, thelper.tasks.Detection):
color_map = task.color_map if task.color_map else {idx: get_label_color_mapping(idx) for idx in task.class_indices.values()}
return draw_bboxes(images=input, preds=pred, bboxes=target, color_map=color_map, redraw=redraw, block=block, **kwargs)
elif isinstance(task, thelper.tasks.Regression):
swap_channels = isinstance(task, thelper.tasks.SuperResolution) # must update BxCxHxW to BxHxWxC in targets/preds
# @@@ todo: cleanup swap_channels above via flag in superres task?
return draw_predicts(images=input, preds=pred, targets=target,
swap_channels=swap_channels, redraw=redraw, block=block, **kwargs)
else:
raise AssertionError("unhandled drawing mode, missing impl")
# noinspection PyUnusedLocal
[docs]def draw_errbars(labels, # type: thelper.typedefs.LabelList
min_values, # type: thelper.typedefs.ArrayType
max_values, # type: thelper.typedefs.ArrayType
stddev_values, # type: thelper.typedefs.ArrayType
mean_values, # type: thelper.typedefs.ArrayType
xlabel="", # type: thelper.typedefs.LabelType
ylabel="Raw Value", # type: thelper.typedefs.LabelType
show=False, # type: Optional[bool]
block=False, # type: Optional[bool]
): # type: (...) -> thelper.typedefs.DrawingType
"""Draws and returns an error bar histogram figure using pyplot."""
if min_values.shape != max_values.shape \
or min_values.shape != stddev_values.shape \
or min_values.shape != mean_values.shape:
raise AssertionError("input dim mismatch")
if len(min_values.shape) != 1 and len(min_values.shape) != 2:
raise AssertionError("input dim unexpected")
    if len(min_values.shape) == 1:
        # np.expand_dims returns a new array; the results must be reassigned
        min_values = np.expand_dims(min_values, 1)
        max_values = np.expand_dims(max_values, 1)
        stddev_values = np.expand_dims(stddev_values, 1)
        mean_values = np.expand_dims(mean_values, 1)
nb_subplots = min_values.shape[1]
    fig, axs = plt.subplots(nb_subplots)
    axs = np.atleast_1d(axs)  # plt.subplots returns a bare Axes object when nb_subplots == 1
xrange = range(len(labels))
for ax_idx in range(nb_subplots):
ax = axs[ax_idx]
ax.locator_params(nbins=nb_subplots)
ax.errorbar(xrange, mean_values[:, ax_idx], stddev_values[:, ax_idx], fmt='ok', lw=3)
ax.errorbar(xrange, mean_values[:, ax_idx], [mean_values[:, ax_idx] - min_values[:, ax_idx],
max_values[:, ax_idx] - mean_values[:, ax_idx]],
fmt='.k', ecolor='gray', lw=1)
ax.set_xticks(xrange)
ax.set_xticklabels(labels, visible=(ax_idx == nb_subplots - 1))
ax.set_title("Band %d" % (ax_idx + 1))
ax.tick_params(axis="x", labelsize="6", labelrotation=45)
fig.set_tight_layout(True)
if show:
fig.show()
if block:
plt.show(block=block)
return fig
plt.pause(0.5)
return fig, axs
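# a minimal usage sketch for draw_errbars (toy statistics; all value arrays are
# NxB for N labeled entries and B bands, with one subplot drawn per band):
#   labels = ["site-a", "site-b", "site-c"]
#   means = np.random.rand(3, 2)
#   stddevs = np.full((3, 2), 0.1)
#   draw_errbars(labels, means - 0.2, means + 0.2, stddevs, means, show=True)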
[docs]def draw_roc_curve(fpr, tpr, labels=None, size_inch=(5, 5), dpi=320, show=False, block=False):
"""Draws and returns an ROC curve figure using pyplot."""
if not isinstance(fpr, np.ndarray) or not isinstance(tpr, np.ndarray):
raise AssertionError("invalid inputs")
if fpr.shape != tpr.shape:
raise AssertionError("mismatched input sizes")
if fpr.ndim == 1:
fpr = np.expand_dims(fpr, 0)
if tpr.ndim == 1:
tpr = np.expand_dims(tpr, 0)
if labels is not None:
if isinstance(labels, str):
labels = [labels]
if len(labels) != fpr.shape[0]:
raise AssertionError("should have one label per curve")
else:
labels = [None] * fpr.shape[0]
fig = plt.figure(num="roc", figsize=size_inch, dpi=dpi, facecolor="w", edgecolor="k")
fig.clf()
ax = fig.add_subplot(1, 1, 1)
for idx, label in enumerate(labels):
auc = sklearn.metrics.auc(fpr[idx, ...], tpr[idx, ...])
if label is not None:
ax.plot(fpr[idx, ...], tpr[idx, ...], "b", label=("%s [auc = %0.3f]" % (label, auc)))
else:
ax.plot(fpr[idx, ...], tpr[idx, ...], "b", label=("auc = %0.3f" % auc))
ax.legend(loc="lower right")
ax.plot([0, 1], [0, 1], 'r--')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_ylabel("True Positive Rate")
ax.set_xlabel("False Positive Rate")
fig.set_tight_layout(True)
if show:
fig.show()
if block:
plt.show(block=block)
return fig
plt.pause(0.5)
return fig, ax
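# a minimal usage sketch for draw_roc_curve (toy curve; fpr/tpr may be 1-d
# arrays or stacked 2-d arrays with one row per curve):
#   fpr = np.linspace(0, 1, 100)
#   tpr = np.sqrt(fpr)
#   fig, ax = draw_roc_curve(fpr, tpr, labels="toy model", show=True)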
[docs]def draw_confmat(confmat, class_list, size_inch=(5, 5), dpi=320, normalize=False,
keep_unset=False, show=False, block=False):
"""Draws and returns an a confusion matrix figure using pyplot."""
if not isinstance(confmat, np.ndarray) or not isinstance(class_list, list):
raise AssertionError("invalid inputs")
if confmat.ndim != 2:
raise AssertionError("invalid confmat shape")
if not keep_unset and "<unset>" in class_list:
unset_idx = class_list.index("<unset>")
del class_list[unset_idx]
        confmat = np.delete(confmat, unset_idx, 0)  # np.delete returns a copy that must be reassigned
        confmat = np.delete(confmat, unset_idx, 1)
if normalize:
row_sums = confmat.sum(axis=1)[:, np.newaxis]
        confmat = np.nan_to_num(confmat.astype(float) / np.maximum(row_sums, 0.0001))  # np.float alias removed in numpy>=1.24
fig = plt.figure(num="confmat", figsize=size_inch, dpi=dpi, facecolor="w", edgecolor="k")
fig.clf()
ax = fig.add_subplot(1, 1, 1)
ax.imshow(confmat, cmap=plt.cm.Blues, aspect="equal", interpolation="none")
labels = [clipstr(label, 9) for label in class_list]
tick_marks = np.arange(len(labels))
ax.set_xlabel("Predicted", fontsize=7)
ax.set_xticks(tick_marks)
ax.set_xticklabels(labels, fontsize=4, rotation=-90, ha="center")
ax.xaxis.set_label_position("bottom")
ax.xaxis.tick_bottom()
ax.set_ylabel("Real", fontsize=7)
ax.set_yticks(tick_marks)
ax.set_yticklabels(labels, fontsize=4, va="center")
ax.set_ylim(confmat.shape[0] - 0.5, -0.5)
ax.yaxis.set_label_position("left")
ax.yaxis.tick_left()
thresh = confmat.max() / 2.
for i, j in itertools.product(range(confmat.shape[0]), range(confmat.shape[1])):
if not normalize:
txt = ("%d" % confmat[i, j]) if confmat[i, j] != 0 else "."
else:
if confmat[i, j] >= 0.01:
txt = "%.02f" % confmat[i, j]
else:
txt = "~0" if confmat[i, j] > 0 else "."
color = "white" if confmat[i, j] > thresh else "black"
ax.text(j, i, txt, horizontalalignment="center", fontsize=4, verticalalignment="center", color=color)
fig.set_tight_layout(True)
if show:
fig.show()
if block:
plt.show(block=block)
return fig
plt.pause(0.5)
return fig, ax
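# a minimal usage sketch for draw_confmat (toy 2x2 matrix; rows are ground-truth
# classes, columns are predictions):
#   confmat = np.asarray([[50, 10], [5, 35]])
#   fig, ax = draw_confmat(confmat, ["cat", "dog"], normalize=True, show=True)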
[docs]def draw_bbox(image, tl, br, text, color, box_thickness=2, font_thickness=1,
font_scale=0.4, show=False, block=False, win_name="bbox"):
"""Draws a single bounding box on a given image (used in :func:`thelper.utils.draw_bboxes`)."""
text_size, baseline = cv.getTextSize(text, fontFace=cv.FONT_HERSHEY_SIMPLEX,
fontScale=font_scale, thickness=font_thickness)
text_bl = (tl[0] + box_thickness + 1, tl[1] + text_size[1] + box_thickness + 1)
# note: text will overflow if box is too small
text_box_br = (text_bl[0] + text_size[0] + box_thickness, text_bl[1] + box_thickness * 2)
cv.rectangle(image, (tl[0] - 1, tl[1] - 1), (text_box_br[0] + 1, text_box_br[1] + 1),
color=(0, 0, 0), thickness=-1)
cv.rectangle(image, tl, br, color=(0, 0, 0), thickness=(box_thickness + 1))
cv.rectangle(image, tl, br, color=color, thickness=box_thickness)
cv.rectangle(image, tl, text_box_br, color=color, thickness=-1)
cv.putText(image, text, text_bl, fontFace=cv.FONT_HERSHEY_SIMPLEX, fontScale=font_scale,
color=(0, 0, 0), thickness=font_thickness + 1)
cv.putText(image, text, text_bl, fontFace=cv.FONT_HERSHEY_SIMPLEX, fontScale=font_scale,
color=(255, 255, 255), thickness=font_thickness)
if show:
cv.imshow(win_name, image)
cv.waitKey(0 if block else 1)
return win_name, image
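# a minimal usage sketch for draw_bbox (draws one labeled box on a blank BGR canvas):
#   canvas = np.zeros((128, 128, 3), dtype=np.uint8)
#   draw_bbox(canvas, (16, 16), (96, 96), "person (0.87)", color=(0, 255, 0),
#             show=True, block=True)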
[docs]def draw_bboxes(images, preds=None, bboxes=None, color_map=None, redraw=None, block=False, min_confidence=0.5, **kwargs):
"""Draws and returns a set of bounding box prediction results."""
image_list = [get_displayable_image(images[batch_idx, ...]) for batch_idx in range(images.shape[0])]
if color_map is not None and isinstance(color_map, dict):
assert len(color_map) <= 256, "too many indices for uint8 map"
color_map_new = np.zeros((256, 3), dtype=np.uint8)
for idx, val in color_map.items():
color_map_new[idx, ...] = val
color_map = color_map_new.tolist()
nb_imgs = len(image_list)
grid_size_x, grid_size_y = nb_imgs, 1 # all images on one row, by default (add gt and preds as extra rows)
box_thickness = thelper.utils.get_key_def("box_thickness", kwargs, default=2, delete=True)
font_thickness = thelper.utils.get_key_def("font_thickness", kwargs, default=1, delete=True)
font_scale = thelper.utils.get_key_def("font_scale", kwargs, default=0.4, delete=True)
if preds is not None:
assert len(image_list) == len(preds)
for preds_list, image in zip(preds, image_list):
for bbox_idx, bbox in enumerate(preds_list):
assert isinstance(bbox, thelper.data.BoundingBox), "unrecognized bbox type"
if bbox.confidence is not None and bbox.confidence < min_confidence:
continue
color = get_bgr_from_hsl(bbox_idx / len(preds_list) * 360, 1.0, 0.5) \
if color_map is None else color_map[bbox.class_id]
conf = ""
if thelper.utils.is_scalar(bbox.confidence):
conf = f" ({bbox.confidence:.3f})"
elif isinstance(bbox.confidence, (list, tuple, np.ndarray)):
conf = f" ({bbox.confidence[bbox.class_id]:.3f})"
draw_bbox(image, bbox.top_left, bbox.bottom_right, f"{bbox.task.class_names[bbox.class_id]}{conf}", color,
box_thickness=box_thickness, font_thickness=font_thickness, font_scale=font_scale)
if bboxes is not None:
assert len(image_list) == len(bboxes), "mismatched bboxes list and image list sizes"
clean_image_list = [get_displayable_image(images[batch_idx, ...]) for batch_idx in range(images.shape[0])]
for bboxes_list, image in zip(bboxes, clean_image_list):
for bbox_idx, bbox in enumerate(bboxes_list):
assert isinstance(bbox, thelper.data.BoundingBox), "unrecognized bbox type"
color = get_bgr_from_hsl(bbox_idx / len(bboxes_list) * 360, 1.0, 0.5) \
if color_map is None else color_map[bbox.class_id]
draw_bbox(image, bbox.top_left, bbox.bottom_right, f"GT: {bbox.task.class_names[bbox.class_id]}", color,
box_thickness=box_thickness, font_thickness=font_thickness, font_scale=font_scale)
grid_size_y += 1
image_list += clean_image_list
return draw_images(image_list, redraw=redraw, window_name="detections", block=block,
grid_size_x=grid_size_x, grid_size_y=grid_size_y, **kwargs)
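# usage sketch for draw_bboxes ('pred_bboxes'/'gt_bboxes' are placeholders for
# per-image lists of thelper.data.BoundingBox objects from a detection model/loader):
#   color_map = {idx: get_label_color_mapping(idx) for idx in range(nb_classes)}
#   draw_bboxes(images, preds=pred_bboxes, bboxes=gt_bboxes,
#               color_map=color_map, min_confidence=0.25, block=True)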
[docs]def get_label_color_mapping(idx):
"""Returns the PASCAL VOC color triplet for a given label index."""
# https://gist.github.com/wllhf/a4533e0adebe57e3ed06d4b50c8419ae
def bitget(byteval, ch):
return (byteval & (1 << ch)) != 0
r = g = b = 0
for j in range(8):
r = r | (bitget(idx, 0) << 7 - j)
g = g | (bitget(idx, 1) << 7 - j)
b = b | (bitget(idx, 2) << 7 - j)
idx = idx >> 3
return np.array([r, g, b], dtype=np.uint8)
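# example: following the VOC convention, label 1 maps to dark red
#   get_label_color_mapping(1)  # -> array([128, 0, 0], dtype=uint8)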
[docs]def apply_color_map(image, colormap, dst=None):
"""Applies a color map to an image of 8-bit color indices; works similarly to cv2.applyColorMap (v3.3.1)."""
if not isinstance(image, np.ndarray) or image.ndim != 2:
raise AssertionError("invalid input image")
if not isinstance(colormap, np.ndarray) or colormap.shape != (256, 1, 3) or (colormap.dtype != np.uint8 and
colormap.dtype != np.float32):
raise AssertionError("invalid color map")
out_shape = (image.shape[0], image.shape[1], 3)
if dst is None:
dst = np.empty(out_shape, dtype=colormap.dtype)
elif not isinstance(dst, np.ndarray) or dst.shape != out_shape or dst.dtype != colormap.dtype:
raise AssertionError("invalid output image")
# using np.take might avoid an extra allocation...
np.copyto(dst, colormap.squeeze()[image.ravel(), :].reshape(out_shape))
return dst
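# a minimal usage sketch for apply_color_map (paints label 1 red in BGR while
# label 0 stays black):
#   colormap = np.zeros((256, 1, 3), dtype=np.uint8)
#   colormap[1, 0, :] = (0, 0, 255)
#   label_image = np.random.randint(0, 2, size=(64, 64)).astype(np.uint8)
#   colored = apply_color_map(label_image, colormap)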
[docs]def stringify_confmat(confmat, class_list, hide_zeroes=False, hide_diagonal=False, hide_threshold=None):
"""Transforms a confusion matrix array obtained in list or numpy format into a printable string."""
if not isinstance(confmat, np.ndarray) or not isinstance(class_list, list):
raise AssertionError("invalid inputs")
column_width = 9
empty_cell = " " * column_width
fst_empty_cell = (column_width - 3) // 2 * " " + "t/p" + (column_width - 3) // 2 * " "
if len(fst_empty_cell) < len(empty_cell):
fst_empty_cell = " " * (len(empty_cell) - len(fst_empty_cell)) + fst_empty_cell
res = "\t" + fst_empty_cell + " "
for label in class_list:
res += ("%{0}s".format(column_width) % clipstr(label, column_width)) + " "
res += ("%{0}s".format(column_width) % "total") + "\n"
for idx_true, label in enumerate(class_list):
res += ("\t%{0}s".format(column_width) % clipstr(label, column_width)) + " "
for idx_pred, _ in enumerate(class_list):
cell = "%{0}d".format(column_width) % int(confmat[idx_true, idx_pred])
if hide_zeroes:
cell = cell if int(confmat[idx_true, idx_pred]) != 0 else empty_cell
if hide_diagonal:
cell = cell if idx_true != idx_pred else empty_cell
if hide_threshold:
cell = cell if confmat[idx_true, idx_pred] > hide_threshold else empty_cell
res += cell + " "
res += ("%{0}d".format(column_width) % int(confmat[idx_true, :].sum())) + "\n"
res += ("\t%{0}s".format(column_width) % "total") + " "
for idx_pred, _ in enumerate(class_list):
res += ("%{0}d".format(column_width) % int(confmat[:, idx_pred].sum())) + " "
res += ("%{0}d".format(column_width) % int(confmat.sum())) + "\n"
return res
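# a minimal usage sketch for stringify_confmat (prints a fixed-width text table
# with per-row/per-column totals):
#   print(stringify_confmat(np.asarray([[50, 10], [5, 35]]), ["cat", "dog"]))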
[docs]def fig2array(fig):
"""Transforms a pyplot figure into a numpy-compatible RGB array."""
fig.canvas.draw()
w, h = fig.canvas.get_width_height()
buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    buf = buf.reshape((h, w, 3))  # canvas buffer is row-major: height first, then width
return buf
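# a minimal usage sketch for fig2array (renders a figure into an HxWx3 uint8 array):
#   fig = plt.figure()
#   plt.plot([0, 1], [0, 1])
#   rgb = fig2array(fig)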
[docs]def get_glob_paths(input_glob_pattern, can_be_dir=False):
"""Parse a wildcard-compatible file name pattern for valid file paths."""
glob_file_paths = glob.glob(input_glob_pattern)
if not glob_file_paths:
raise AssertionError("invalid input glob pattern '%s' (no matches found)" % input_glob_pattern)
for file_path in glob_file_paths:
if not os.path.isfile(file_path) and not (can_be_dir and os.path.isdir(file_path)):
raise AssertionError("invalid input file at globed path '%s'" % file_path)
return glob_file_paths
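# a minimal usage sketch for get_glob_paths (hypothetical pattern; raises if
# nothing matches):
#   image_paths = get_glob_paths("/datasets/samples/*.png")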
[docs]def get_file_paths(input_path, data_root, allow_glob=False, can_be_dir=False):
"""Parse a wildcard-compatible file name pattern at a given root level for valid file paths."""
if os.path.isabs(input_path):
if '*' in input_path and allow_glob:
return get_glob_paths(input_path)
elif not os.path.isfile(input_path) and not (can_be_dir and os.path.isdir(input_path)):
raise AssertionError("invalid input file at absolute path '%s'" % input_path)
else:
if not os.path.isdir(data_root):
raise AssertionError("invalid dataset root directory at '%s'" % data_root)
input_path = os.path.join(data_root, input_path)
if '*' in input_path and allow_glob:
return get_glob_paths(input_path)
elif not os.path.isfile(input_path) and not (can_be_dir and os.path.isdir(input_path)):
raise AssertionError("invalid input file at path '%s'" % input_path)
return [input_path]
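# a minimal usage sketch for get_file_paths (hypothetical paths; relative
# patterns are resolved under the given data root before validation):
#   paths = get_file_paths("imagery/*.tif", "/datasets/root", allow_glob=True)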