"""
Command-line module, for use with a ``__main__`` entrypoint.

This module contains the primary functions used to create or resume a training session, to start a
visualization session, or to start an annotation session. To create any kind of session, the user must
provide a configuration dictionary. For sessions that produce outputs, the path to a directory where
the data should be saved is also required.
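
Usage example, assuming the package provides the ``__main__`` entrypoint mentioned above and a JSON
session configuration exists at ``config.json`` (all paths below are placeholders)::

    python -m thelper new config.json /path/to/save/root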
"""
import argparse
import json
import logging
import os
from typing import Any, Union

import torch
import tqdm

import thelper


def create_session(config, save_dir):
"""Creates a session to train a model.
All generated outputs (model checkpoints and logs) will be saved in a directory named after the
session (the name itself is specified in ``config``), and located in ``save_dir``.
Args:
config: a dictionary that provides all required data configuration and trainer parameters; see
:class:`thelper.train.base.Trainer` and :func:`thelper.data.utils.create_loaders` for more information.
Here, it is only expected to contain a ``name`` field that specifies the name of the session.
save_dir: the path to the root directory where the session directory should be saved. Note that
this is not the path to the session directory itself, but its parent, which may also contain
other session directories.
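
    Usage example (a minimal sketch; a real configuration would also define the ``datasets``,
    ``loaders``, ``model``, and ``trainer`` sections documented in the links below)::

        import json
        import thelper

        with open("config.json") as fd:
            config = json.load(fd)
        thelper.cli.create_session(config, save_dir="sessions")
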
.. seealso::
| :class:`thelper.train.base.Trainer`
"""
logger = thelper.utils.get_func_logger()
if "name" not in config or not config["name"]:
raise AssertionError("config missing 'name' field")
session_name = config["name"]
logger.info("creating new training session '%s'..." % session_name)
thelper.utils.setup_globals(config)
save_dir = thelper.utils.get_save_dir(save_dir, session_name, config)
logger.debug("session will be saved at '%s'" % os.path.abspath(save_dir).replace("\\", "/"))
task, train_loader, valid_loader, test_loader = thelper.data.create_loaders(config, save_dir)
model = thelper.nn.create_model(config, task, save_dir=save_dir)
loaders = (train_loader, valid_loader, test_loader)
trainer = thelper.train.create_trainer(session_name, save_dir, config, model, task, loaders)
logger.debug("starting trainer")
if train_loader:
trainer.train()
else:
trainer.eval()
logger.debug("all done")
return trainer.outputs


def resume_session(ckptdata, save_dir, config=None, eval_only=False):
"""Resumes a previously created training session.
Since the saved checkpoints contain the original session's configuration, the ``config`` argument
can be set to ``None`` if the session should simply pick up where it was interrupted. Otherwise,
the ``config`` argument can be set to a new configuration that will override the older one. This is
useful when fine-tuning a model, or when testing on a new dataset.
.. warning::
If a session is resumed with an overriding configuration, the user must make sure that the
inputs/outputs of the older model are compatible with the new parameters. For example, with
classifiers, this means that the number of classes parsed by the dataset (and thus to be
predicted by the model) should remain the same. This is a limitation of the framework that
should be addressed in a future update.
.. warning::
A resumed session will not be compatible with its original RNG states if the number of workers
used is changed. To get 100% reproducible results, make sure you run with the same worker count.
Args:
ckptdata: raw checkpoint data loaded via ``torch.load()``; it will be parsed by the various
parts of the framework that need to reload their previous state.
save_dir: the path to the root directory where the session directory should be saved. Note that
this is not the path to the session directory itself, but its parent, which may also contain
other session directories.
config: a dictionary that provides all required data configuration and trainer parameters; see
:class:`thelper.train.base.Trainer` and :func:`thelper.data.utils.create_loaders` for more information.
Here, it is only expected to contain a ``name`` field that specifies the name of the session.
eval_only: specifies whether training should be resumed or the model should only be evaluated.
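
    Usage example (a minimal sketch; the checkpoint path is a placeholder)::

        import thelper

        ckptdata = thelper.utils.load_checkpoint("sessions/my-session/checkpoints/ckpt.best.pth",
                                                 map_location="cpu")
        thelper.cli.resume_session(ckptdata, save_dir="sessions", eval_only=True)
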
.. seealso::
| :class:`thelper.train.base.Trainer`
"""
logger = thelper.utils.get_func_logger()
    if not ckptdata:
        raise AssertionError("must provide valid checkpoint data to resume a session!")
if not config:
if "config" not in ckptdata or not ckptdata["config"]:
raise AssertionError("checkpoint data missing 'config' field")
config = ckptdata["config"]
if "name" not in config or not config["name"]:
raise AssertionError("config missing 'name' field")
session_name = config["name"]
logger.info("loading training session '%s' objects..." % session_name)
thelper.utils.setup_globals(config)
save_dir = thelper.utils.get_save_dir(save_dir, session_name, config, resume=True)
logger.debug("session will be saved at '%s'" % os.path.abspath(save_dir).replace("\\", "/"))
new_task, train_loader, valid_loader, test_loader = thelper.data.create_loaders(config, save_dir)
if "task" not in ckptdata or not ckptdata["task"] or not isinstance(ckptdata["task"], (thelper.tasks.Task, str)):
raise AssertionError("invalid checkpoint, cannot reload previous model task")
old_task = thelper.tasks.create_task(ckptdata["task"]) if isinstance(ckptdata["task"], str) else ckptdata["task"]
if not old_task.check_compat(new_task, exact=True):
compat_task = None if not old_task.check_compat(new_task) else old_task.get_compat(new_task)
choice = thelper.utils.query_string("Found discrepancy between old task from checkpoint and new task from config; "+
"which one would you like to resume the session with?\n" +
f"\told: {str(old_task)}\n\tnew: {str(new_task)}\n" +
(f"\tcompat: {str(compat_task)}\n\n" if compat_task is not None else "\n") +
"WARNING: if resuming with new or compat, some weights might be discarded!",
choices=["old", "new", "compat"])
task = old_task if choice == "old" else new_task if choice == "new" else compat_task
if choice != "old":
# saved optimizer state might cause issues with mismatched tasks, let's get rid of it
logger.warning("dumping optimizer state to avoid issues when resuming with modified task")
ckptdata["optimizer"], ckptdata["scheduler"] = None, None
else:
task = new_task
assert task is not None, "invalid task"
model = thelper.nn.create_model(config, task, save_dir=save_dir, ckptdata=ckptdata)
loaders = (None if eval_only else train_loader, valid_loader, test_loader)
trainer = thelper.train.create_trainer(session_name, save_dir, config, model, task, loaders, ckptdata=ckptdata)
if eval_only:
logger.info("evaluating session '%s' checkpoint @ epoch %d" % (trainer.name, trainer.current_epoch))
trainer.eval()
else:
logger.info("resuming training session '%s' @ epoch %d" % (trainer.name, trainer.current_epoch))
trainer.train()
logger.debug("all done")
return trainer.outputs


def visualize_data(config):
"""Displays the images used in a training session.
This mode does not generate any output, and is only used to visualize the (transformed) images used
in a training session. This is useful to debug the data augmentation and base transformation pipelines
and make sure the modified images are valid. It does not attempt to load a model or instantiate a
trainer, meaning the related fields are not required inside ``config``.
If the configuration dictionary includes a 'loaders' field, it will be parsed and used. Otherwise,
if only a 'datasets' field is available, basic loaders will be instantiated to load the data. The
'loaders' field can also be ignored if 'ignore_loaders' is found within the 'viz' section of the config
and set to ``True``. Each minibatch will be displayed via pyplot or OpenCV. The display will block and
wait for user input, unless 'block' is set within the 'viz' section's 'kwargs' config as ``False``.
Args:
config: a dictionary that provides all required data configuration parameters; see
:func:`thelper.data.utils.create_loaders` for more information.
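
    Example configuration excerpt (a minimal sketch; the dataset definition is a placeholder)::

        {
            "datasets": {
                "my_dataset": {"type": "...", "params": {}}
            },
            "viz": {
                "ignore_loaders": true,
                "kwargs": {"block": false}
            }
        }
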
.. seealso::
| :func:`thelper.data.utils.create_loaders`
| :func:`thelper.data.utils.create_parsers`
"""
logger = thelper.utils.get_func_logger()
logger.info("creating visualization session...")
thelper.utils.setup_globals(config)
viz_config = thelper.utils.get_key_def("viz", config, default={})
if not isinstance(viz_config, dict):
raise AssertionError("unexpected viz config type")
ignore_loaders = thelper.utils.get_key_def("ignore_loaders", viz_config, default=False)
viz_kwargs = thelper.utils.get_key_def("kwargs", viz_config, default={})
if not isinstance(viz_kwargs, dict):
raise AssertionError("unexpected viz kwargs type")
if thelper.utils.get_key_def(["data_config", "loaders"], config, default=None) is None or ignore_loaders:
datasets, task = thelper.data.create_parsers(config)
        loader_map = {dataset_name: thelper.data.DataLoader(dataset) for dataset_name, dataset in datasets.items()}
# we assume no transforms were done in the parser, and images are given as read by opencv
viz_kwargs["ch_transpose"] = thelper.utils.get_key_def("ch_transpose", viz_kwargs, False)
viz_kwargs["flip_bgr"] = thelper.utils.get_key_def("flip_bgr", viz_kwargs, False)
else:
task, train_loader, valid_loader, test_loader = thelper.data.create_loaders(config)
loader_map = {"train": train_loader, "valid": valid_loader, "test": test_loader}
# we assume transforms were done in the loader, and images are given as expected by pytorch
viz_kwargs["ch_transpose"] = thelper.utils.get_key_def("ch_transpose", viz_kwargs, True)
viz_kwargs["flip_bgr"] = thelper.utils.get_key_def("flip_bgr", viz_kwargs, False)
redraw = None
viz_kwargs["block"] = thelper.utils.get_key_def("block", viz_kwargs, default=True)
assert "quit" not in loader_map
choices = list(loader_map.keys()) + ["quit"]
while True:
choice = thelper.utils.query_string("Which loader would you like to visualize?", choices=choices)
if choice == "quit":
break
loader = loader_map[choice]
if loader is None:
logger.info("loader '%s' is empty" % choice)
continue
batch_count = len(loader)
logger.info("initializing loader '%s' with %d batches..." % (choice, batch_count))
for sample in tqdm.tqdm(loader):
            input_tensor = sample[task.input_key]  # renamed to avoid shadowing the built-in 'input'
            target = sample[task.gt_key] if task.gt_key in sample else None
            redraw = thelper.utils.draw(task=task, input=input_tensor, target=target, redraw=redraw, **viz_kwargs)
logger.info("all done")


def annotate_data(config, save_dir):
"""Launches an annotation session for a dataset using a specialized GUI tool.
Note that the annotation type must be supported by the GUI tool. The annotations created by the user
during the session will be saved in the session directory.
Args:
config: a dictionary that provides all required dataset and GUI tool configuration parameters; see
:func:`thelper.data.utils.create_parsers` and :func:`thelper.gui.utils.create_annotator` for more
information.
save_dir: the path to the root directory where the session directory should be saved. Note that
this is not the path to the session directory itself, but its parent, which may also contain
other session directories.
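
    Usage example, from the command line (paths are placeholders)::

        python -m thelper annot config.json /path/to/save/root
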
.. seealso::
        | :class:`thelper.gui.annotators.Annotator`
        | :class:`thelper.gui.annotators.ImageSegmentAnnotator`
"""
# import gui here since it imports packages that will cause errors in CLI-only environments
import thelper.gui
logger = thelper.utils.get_func_logger()
if "name" not in config or not config["name"]:
raise AssertionError("config missing 'name' field")
session_name = config["name"]
logger.info("creating annotation session '%s'..." % session_name)
thelper.utils.setup_globals(config)
save_dir = thelper.utils.get_save_dir(save_dir, session_name, config)
logger.debug("session will be saved at '%s'" % os.path.abspath(save_dir).replace("\\", "/"))
datasets, _ = thelper.data.create_parsers(config)
annotator = thelper.gui.create_annotator(session_name, save_dir, config, datasets)
logger.debug("starting annotator")
annotator.run()
logger.debug("all done")


def split_data(config, save_dir):
"""Launches a dataset splitting session.
This mode will generate an HDF5 archive that contains the split datasets defined in the session
configuration file. This archive can then be reused in a new training session to guarantee a fixed
distribution of training, validation, and testing samples. It can also be used outside the framework
in order to reproduce an experiment.
The configuration dictionary must minimally contain two sections: 'datasets' and 'loaders'. A third
section, 'split', can be used to provide settings regarding the archive packing and compression
approaches to use.
The HDF5 archive will be saved in the session's output directory.
Args:
config: a dictionary that provides all required data configuration parameters; see
:func:`thelper.data.utils.create_loaders` for more information.
save_dir: the path to the root directory where the session directory should be saved. Note that
this is not the path to the session directory itself, but its parent, which may also contain
other session directories.
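
    Example ``split`` section (a minimal sketch; the compression parameters are hypothetical and are
    forwarded as-is to :func:`thelper.data.utils.create_hdf5`)::

        "split": {
            "archive_name": "my-dataset.hdf5",
            "compression": {"type": "gzip"}
        }
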
.. seealso::
| :func:`thelper.data.utils.create_loaders`
| :func:`thelper.data.utils.create_hdf5`
| :class:`thelper.data.parsers.HDF5Dataset`
"""
logger = thelper.utils.get_func_logger()
if "name" not in config or not config["name"]:
raise AssertionError("config missing 'name' field")
session_name = config["name"]
split_config = thelper.utils.get_key_def("split", config, default={})
if not isinstance(split_config, dict):
raise AssertionError("unexpected split config type")
compression = thelper.utils.get_key_def("compression", split_config, default={})
if not isinstance(compression, dict):
raise AssertionError("compression params should be given as dictionary")
archive_name = thelper.utils.get_key_def("archive_name", split_config, default=(session_name + ".hdf5"))
logger.info("creating new splitting session '%s'..." % session_name)
thelper.utils.setup_globals(config)
save_dir = thelper.utils.get_save_dir(save_dir, session_name, config)
logger.debug("session will be saved at '%s'" % os.path.abspath(save_dir).replace("\\", "/"))
task, train_loader, valid_loader, test_loader = thelper.data.create_loaders(config, save_dir)
archive_path = os.path.join(save_dir, archive_name)
thelper.data.create_hdf5(archive_path, task, train_loader, valid_loader, test_loader, compression, config)
logger.debug("all done")


def export_model(config, save_dir):
"""Launches a model exportation session.
This function will export a model defined via a configuration file into a new checkpoint that can be
loaded elsewhere. The model can be built using the framework, or provided via its type, construction
parameters, and weights. Its exported format will be compatible with the framework, and may also be an
optimized/compiled version obtained using PyTorch's JIT tracer.
The configuration dictionary must minimally contain a 'model' section that provides details on the model
to be exported. A section named 'export' can be used to provide settings regarding the exportation
approaches to use, and the task interface to save with the model. If a task is not explicitly defined in
the 'export' section, the session configuration will be parsed for a 'datasets' section that can be used
to define it. Otherwise, it must be provided through the model.
The exported checkpoint containing the model will be saved in the session's output directory.
Args:
config: a dictionary that provides all required data configuration parameters; see
:func:`thelper.nn.utils.create_model` for more information.
save_dir: the path to the root directory where the session directory should be saved. Note that
this is not the path to the session directory itself, but its parent, which may also contain
other session directories.
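
    Example ``export`` section (a minimal sketch; the ``trace_input`` string is evaluated at runtime to
    build the example input tensor, and the shape below is only illustrative)::

        "export": {
            "ckpt_name": "my-model.export.pth",
            "trace_name": "my-model.trace.zip",
            "save_raw": true,
            "trace_input": "torch.rand(1, 3, 224, 224)"
        }
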
.. seealso::
| :func:`thelper.nn.utils.create_model`
"""
logger = thelper.utils.get_func_logger()
if "name" not in config or not config["name"]:
raise AssertionError("config missing 'name' field")
session_name = config["name"]
export_config = thelper.utils.get_key_def("export", config, default={})
if not isinstance(export_config, dict):
raise AssertionError("unexpected export config type")
ckpt_name = thelper.utils.get_key_def("ckpt_name", export_config, default=(session_name + ".export.pth"))
trace_name = thelper.utils.get_key_def("trace_name", export_config, default=(session_name + ".trace.zip"))
save_raw = thelper.utils.get_key_def("save_raw", export_config, default=True)
trace_input = thelper.utils.get_key_def("trace_input", export_config, default=None)
task = thelper.utils.get_key_def("task", export_config, default=None)
if isinstance(task, (str, dict)):
task = thelper.tasks.create_task(task)
if task is None and "datasets" in config:
_, task = thelper.data.create_parsers(config) # try to load via datasets...
if isinstance(trace_input, str):
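        # note: the string is evaluated as python code to build the example input tensor
        # passed to the jit tracer below (e.g. "torch.rand(1, 3, 224, 224)")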
trace_input = eval(trace_input)
logger.info("exporting model '%s'..." % session_name)
thelper.utils.setup_globals(config)
save_dir = thelper.utils.get_save_dir(save_dir, session_name, config)
logger.debug("exported checkpoint will be saved at '%s'" % os.path.abspath(save_dir).replace("\\", "/"))
model = thelper.nn.create_model(config, task, save_dir=save_dir)
if task is None:
assert hasattr(model, "task"), "model should have task attrib if not provided already"
task = model.task
log_stamp = thelper.utils.get_log_stamp()
model_type = model.get_name()
model_params = model.config if model.config else {}
# the saved state below should be kept compatible with the one in thelper.train.base._save
export_state = {
"name": session_name,
"source": log_stamp,
"git_sha1": thelper.utils.get_git_stamp(),
"version": thelper.__version__,
"task": str(task) if save_raw else task,
"model_type": model_type,
"model_params": model_params,
"config": config
}
if trace_input is not None:
trace_path = os.path.join(save_dir, trace_name)
torch.jit.trace(model, trace_input).save(trace_path)
export_state["model"] = trace_name # will be loaded in thelper.utils.load_checkpoint
else:
export_state["model"] = model.state_dict() if save_raw else model
torch.save(export_state, os.path.join(save_dir, ckpt_name))
logger.debug("all done")


def make_argparser():
# type: () -> argparse.ArgumentParser
"""Creates the (default) argument parser to use for the main entrypoint.
The argument parser will contain different "operating modes" that dictate the high-level behavior of the CLI. This
function may be modified in branches of the framework to add project-specific features.
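
    Usage example (a minimal sketch of a project-specific entrypoint extending this parser; the extra
    flag below is hypothetical)::

        import thelper

        ap = thelper.cli.make_argparser()
        ap.add_argument("--my-flag", action="store_true")  # hypothetical project-specific flag
        thelper.cli.main(argparser=ap)
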
"""
ap = argparse.ArgumentParser(description='thelper model trainer application')
ap.add_argument("--version", default=False, action="store_true", help="prints the version of the library and exits")
ap.add_argument("-l", "--log", default=None, type=str, help="path to the top-level log file (default: None)")
ap.add_argument("-v", "--verbose", action="count", default=0, help="set logging terminal verbosity level (additive)")
ap.add_argument("--silent", action="store_true", default=False, help="deactivates all console logging activities")
ap.add_argument("--force-stdout", action="store_true", default=False, help="force logging output to stdout instead of stderr")
subparsers = ap.add_subparsers(title="Operating mode", dest="mode")
new_ap = subparsers.add_parser("new", help="creates a new session from a config file")
new_ap.add_argument("cfg_path", type=str, help="path to the session configuration file")
new_ap.add_argument("save_dir", type=str, help="path to the root directory where checkpoints should be saved")
cl_new_ap = subparsers.add_parser("cl_new", help="creates a new session from a config file for the cluster")
cl_new_ap.add_argument("cfg_path", type=str, help="path to the session configuration file")
cl_new_ap.add_argument("save_dir", type=str, help="path to the root directory where checkpoints should be saved")
resume_ap = subparsers.add_parser("resume", help="resume a session from a checkpoint file")
resume_ap.add_argument("ckpt_path", type=str, help="path to the checkpoint (or save directory) to resume training from")
resume_ap.add_argument("-s", "--save-dir", default=None, type=str, help="path to the root directory where checkpoints should be saved")
resume_ap.add_argument("-m", "--map-location", default=None, help="map location for loading data (default=None)")
resume_ap.add_argument("-c", "--override-cfg", default=None, help="override config file path (default=None)")
resume_ap.add_argument("-e", "--eval-only", default=False, action="store_true", help="only run evaluation pass (valid+test)")
viz_ap = subparsers.add_parser("viz", help="visualize the loaded data for a training/eval session")
viz_ap.add_argument("cfg_path", type=str, help="path to the session configuration file (or session save directory)")
annot_ap = subparsers.add_parser("annot", help="launches a dataset annotation session with a GUI tool")
annot_ap.add_argument("cfg_path", type=str, help="path to the session configuration file (or session save directory)")
annot_ap.add_argument("save_dir", type=str, help="path to the root directory where annotations should be saved")
split_ap = subparsers.add_parser("split", help="launches a dataset splitting session from a config file")
split_ap.add_argument("cfg_path", type=str, help="path to the session configuration file (or session save directory)")
split_ap.add_argument("save_dir", type=str, help="path to the root directory where the split hdf5 dataset archive should be saved")
split_ap = subparsers.add_parser("export", help="launches a model exportation session from a config file")
split_ap.add_argument("cfg_path", type=str, help="path to the session configuration file (or session save directory)")
split_ap.add_argument("save_dir", type=str, help="path to the root directory where the exported checkpoint should be saved")
return ap


def setup(args=None, argparser=None):
    # type: (Any, argparse.ArgumentParser) -> Union[int, argparse.Namespace]
"""Sets up the argument parser (if not already done externally) and parses the input CLI arguments.
This function may return an error code (integer) if the program should exit immediately. Otherwise, it will return
the parsed arguments to use in order to redirect the execution flow of the entrypoint.
"""
argparser = argparser or make_argparser()
args = argparser.parse_args(args=args)
if args.version:
print(thelper.__version__)
return 0
if args.mode is None:
argparser.print_help()
return 1
if args.silent and args.verbose > 0:
raise AssertionError("contradicting verbose/silent arguments provided")
log_level = logging.INFO if args.verbose < 1 else logging.DEBUG if args.verbose < 2 else logging.NOTSET
thelper.utils.init_logger(log_level, args.log, args.force_stdout)
return args


def main(args=None, argparser=None):
"""Main entrypoint to use with console applications.
This function parses command line arguments and dispatches the execution based on the selected
operating mode. Run with ``--help`` for information on the available arguments.
.. warning::
If you are trying to resume a session that was previously executed using a now unavailable GPU,
you will have to force the checkpoint data to be loaded on CPU using ``--map-location=cpu`` (or
using ``-m=cpu``).
.. seealso::
| :func:`thelper.cli.create_session`
| :func:`thelper.cli.resume_session`
| :func:`thelper.cli.visualize_data`
| :func:`thelper.cli.annotate_data`
| :func:`thelper.cli.split_data`
"""
args = setup(args=args, argparser=argparser)
if isinstance(args, int):
return args # CLI must exit immediately with provided error code
if args.mode == "new" or args.mode == "cl_new":
thelper.logger.debug("parsing config at '%s'" % args.cfg_path)
with open(args.cfg_path) as fd:
config = json.load(fd)
if args.mode == "cl_new":
trainer_config = thelper.utils.get_key_def("trainer", config, {})
device = thelper.utils.get_key_def("device", trainer_config, None)
if device is not None:
raise AssertionError("cannot specify device in config for cluster sessions, it is determined at runtime")
create_session(config, args.save_dir)
elif args.mode == "resume":
ckptdata = thelper.utils.load_checkpoint(args.ckpt_path, map_location=args.map_location,
always_load_latest=(not args.eval_only))
override_config = None
if args.override_cfg:
thelper.logger.debug("parsing override config at '%s'" % args.override_cfg)
with open(args.override_cfg) as fd:
override_config = json.load(fd)
save_dir = args.save_dir
if save_dir is None:
ckpt_dir_path = os.path.dirname(os.path.abspath(args.ckpt_path)) \
if not os.path.isdir(args.ckpt_path) else os.path.abspath(args.ckpt_path)
# find session dir by looking for 'logs' directory
if os.path.isdir(os.path.join(ckpt_dir_path, "logs")):
save_dir = os.path.abspath(os.path.join(ckpt_dir_path, ".."))
elif os.path.isdir(os.path.join(ckpt_dir_path, "../logs")):
save_dir = os.path.abspath(os.path.join(ckpt_dir_path, "../.."))
else:
save_dir = thelper.utils.query_string("Please provide the path to where the resumed session output should be saved:")
save_dir = thelper.utils.get_save_dir(save_dir, dir_name="", config=override_config)
resume_session(ckptdata, save_dir, config=override_config, eval_only=args.eval_only)
else:
thelper.logger.debug("parsing config at '%s'" % args.cfg_path)
with open(args.cfg_path) as fd:
config = json.load(fd)
if args.mode == "viz":
visualize_data(config)
elif args.mode == "annot":
annotate_data(config, args.save_dir)
elif args.mode == "export":
export_model(config, args.save_dir)
else: # if args.mode == "split":
split_data(config, args.save_dir)
return 0
if __name__ == "__main__":
main()