warp_drive.training package
Subpackages
Submodules
warp_drive.training.example_training_script module
Example training script for the grid world and continuous versions of Tag.
- warp_drive.training.example_training_script.setup_trainer_and_train(run_configuration, device_id=0, num_devices=1, event_messenger=None, results_directory=None, verbose=True)
Create the environment wrapper, define the policy mapping to agent ids, create the trainer object, and then perform training.
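The policy mapping passed to the trainer as policy_tag_to_agent_id_map associates each policy tag with the agent ids that policy controls. The following is a minimal sketch of building such a mapping. It assumes a Tag-like environment in which the taggers occupy the first agent ids and the runners occupy the rest. The helper name, the "tagger"/"runner" tags, and the agent ordering are hypothetical and are not taken from the example script.

```python
from typing import Dict, List


def build_policy_tag_to_agent_id_map(num_taggers: int, num_runners: int) -> Dict[str, List[int]]:
    """Hypothetical helper: give the first `num_taggers` agent ids to a
    "tagger" policy and the remaining `num_runners` ids to a "runner" policy."""
    taggers = list(range(num_taggers))
    runners = list(range(num_taggers, num_taggers + num_runners))
    return {"tagger": taggers, "runner": runners}


policy_map = build_policy_tag_to_agent_id_map(num_taggers=2, num_runners=3)
print(policy_map)  # {'tagger': [0, 1], 'runner': [2, 3, 4]}

# Sanity check: every agent id is assigned to exactly one policy.
all_ids = sorted(agent_id for ids in policy_map.values() for agent_id in ids)
assert all_ids == list(range(5))
```

Keeping each agent id under exactly one tag lets every policy's model be applied to a fixed subset of agents. The final assertion checks that invariant.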
warp_drive.training.trainer module
The Trainer, PerfStats and Metrics classes
- class warp_drive.training.trainer.Metrics
Bases: object
Metrics class to log and print the key metrics.
- pretty_print(metrics)
- class warp_drive.training.trainer.PerfStats
Bases: object
Performance stats that will be included in rollout metrics.
- get_perf_stats()
- static pretty_print(stats)
- class warp_drive.training.trainer.Trainer(env_wrapper=None, config=None, policy_tag_to_agent_id_map=None, create_separate_placeholders_for_each_policy=False, obs_dim_corresponding_to_num_agents='first', num_devices=1, device_id=0, results_dir=None, verbose=True)
Bases: object
The trainer object. Contains the methods train(), save_model_checkpoint(), and fetch_episode_states().
- fetch_episode_states(list_of_states=None, env_id=0)
Step through the environment and fetch the desired states (data arrays on the GPU) for an entire episode, using the trained models for evaluation.
- graceful_close()
- load_model_checkpoint(ckpts_dict=None)
Load the model parameters if a checkpoint path is specified.
- save_model_checkpoint(iteration=0)
Save the model parameters.
- train()
Perform training.
- warp_drive.training.trainer.all_equal(iterable)
Check whether all elements of an iterable (e.g., a list) are identical.
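One common way to implement such a check is the itertools "all equal" recipe, sketched below. This is an illustration of the technique, not necessarily how warp_drive implements all_equal. The recipe relies on itertools.groupby, which yields one group per run of equal consecutive elements. All elements are therefore identical exactly when there is at most one group.

```python
from itertools import groupby
from typing import Iterable


def all_equal(iterable: Iterable) -> bool:
    """Return True if every element of `iterable` equals the others.
    An empty iterable is treated as all-equal."""
    groups = groupby(iterable)
    # First next(): True if empty, else the first (key, group) tuple (truthy).
    # Second next(): any further group means some element differs.
    return next(groups, True) and not next(groups, False)


print(all_equal([3, 3, 3]))  # True
print(all_equal([3, 3, 4]))  # False
print(all_equal([]))         # True
```

Because groupby consumes the input lazily, this version works on generators and stops as soon as a second distinct value appears.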
- warp_drive.training.trainer.recursive_merge_config_dicts(config, default_config)
Merge the configuration dictionary with the default configuration dictionary to fill in any missing configuration keys.
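A minimal sketch of a recursive merge with these semantics follows, under stated assumptions. Values already present in config take precedence. Missing keys are filled from default_config. Nested dictionaries are merged level by level, and a new dictionary is returned so that neither input is mutated. Whether the library function mutates its input, returns a new dictionary, or rejects unknown keys is an assumption here. The configuration keys in the usage example ("trainer", "num_envs", "train_batch_size", "saving", "save_dir") are hypothetical.

```python
import copy
from typing import Any, Dict


def recursive_merge_config_dicts(config: Dict[str, Any], default_config: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict with every key of `config`, plus any key of
    `default_config` that `config` lacks. Nested dicts are merged
    recursively; values set in `config` always win."""
    merged = copy.deepcopy(config)
    for key, default_value in default_config.items():
        if key not in merged:
            # Missing key: fill in a copy of the default.
            merged[key] = copy.deepcopy(default_value)
        elif isinstance(merged[key], dict) and isinstance(default_value, dict):
            # Both sides are dicts: merge one level deeper.
            merged[key] = recursive_merge_config_dicts(merged[key], default_value)
    return merged


# Hypothetical configuration keys, for illustration only.
defaults = {"trainer": {"num_envs": 100, "train_batch_size": 10000}, "saving": {"save_dir": "results"}}
user_config = {"trainer": {"num_envs": 400}}
print(recursive_merge_config_dicts(user_config, defaults))
# {'trainer': {'num_envs': 400, 'train_batch_size': 10000}, 'saving': {'save_dir': 'results'}}
```

Deep-copying both the user config and the filled-in defaults keeps later edits to the merged result from leaking back into the default configuration.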
- warp_drive.training.trainer.verbose_print(message, device_id=None)