warp_drive.training package
Subpackages
- warp_drive.training.algorithms package
- warp_drive.training.models package
- warp_drive.training.utils package
  - Subpackages
  - Submodules
    - warp_drive.training.utils.child_process_base module
    - warp_drive.training.utils.data_loader module
    - warp_drive.training.utils.param_scheduler module
    - warp_drive.training.utils.process_group_torch module
    - warp_drive.training.utils.vertical_scaler module
  - Module contents
Submodules
warp_drive.training.example_training_script_numba module
Example training script for the grid-world and continuous versions of Tag, using the Numba backend.
- warp_drive.training.example_training_script_numba.setup_trainer_and_train(run_configuration, device_id=0, num_devices=1, event_messenger=None, results_directory=None, verbose=True)
Create the environment wrapper, define the mapping from policies to agent IDs, create the trainer object, and run training.
warp_drive.training.example_training_script_pycuda module
Example training script for the grid-world and continuous versions of Tag, using the PyCUDA backend.
- warp_drive.training.example_training_script_pycuda.setup_trainer_and_train(run_configuration, device_id=0, num_devices=1, event_messenger=None, results_directory=None, verbose=True)
Create the environment wrapper, define the mapping from policies to agent IDs, create the trainer object, and run training.
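Both `setup_trainer_and_train` functions build a mapping from policy tags to agent IDs, which the `Trainer` receives as `policy_tag_to_agent_id_map`. The sketch below shows one plausible shape for that mapping in Tag. It assumes, purely for illustration, that agents `0..num_taggers-1` are taggers and the remaining IDs are runners; the helper `build_policy_map` and the tags `"tagger"` and `"runner"` are hypothetical, not taken from the scripts.

```python
from typing import Dict, List


def build_policy_map(num_agents: int, num_taggers: int) -> Dict[str, List[int]]:
    """Hypothetical helper: map each policy tag to the agent IDs it controls.

    Assumes taggers occupy IDs 0..num_taggers-1 and runners the remaining IDs.
    """
    if not 0 <= num_taggers <= num_agents:
        raise ValueError("num_taggers must lie in [0, num_agents]")
    return {
        "tagger": list(range(num_taggers)),
        "runner": list(range(num_taggers, num_agents)),
    }


policy_map = build_policy_map(num_agents=5, num_taggers=2)

# Every agent ID should appear under exactly one policy tag.
all_ids = sorted(agent_id for ids in policy_map.values() for agent_id in ids)
assert all_ids == list(range(5))
print(policy_map)  # {'tagger': [0, 1], 'runner': [2, 3, 4]}
```

Keeping each policy's agent IDs in a plain list makes it easy to verify that every agent is assigned to exactly one policy, as the assertion does.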
warp_drive.training.pytorch_lightning module¶
warp_drive.training.trainer module¶
The Trainer, PerfStats and Metrics classes
- class warp_drive.training.trainer.Metrics
Bases: object
Logs and prints the key metrics.
- pretty_print(metrics)
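`Metrics.pretty_print(metrics)` is documented only by name. As a hedged illustration, the hypothetical `format_metrics` helper below assumes `metrics` is a flat dict mapping metric names to floats and renders one aligned line per entry; the real method's input structure and output format may differ.

```python
from typing import Dict


def format_metrics(metrics: Dict[str, float], precision: int = 4) -> str:
    """Hypothetical stand-in for Metrics.pretty_print: one aligned 'name : value' line per metric."""
    if not metrics:
        return ""
    # Pad every name to the longest one so the values line up in a column.
    width = max(len(name) for name in metrics)
    lines = [f"{name:<{width}} : {value:.{precision}f}" for name, value in metrics.items()]
    return "\n".join(lines)


print(format_metrics({"Mean episodic reward": 1.25, "Policy loss": -0.031}))
```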
- class warp_drive.training.trainer.PerfStats
Bases: object
Performance statistics to be included in the rollout metrics.
- get_perf_stats()
- static pretty_print(stats)
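The documentation does not list which statistics `PerfStats` tracks. As a hedged illustration of the idea, the hypothetical `SimplePerfStats` below accumulates wall-clock time per training phase across iterations and derives per-iteration means and throughput in a `get_perf_stats()`-style method. The field names, the two phases, and the output keys are all assumptions.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class SimplePerfStats:
    """Hypothetical PerfStats-like accumulator; field names are illustrative only."""

    iters: int = 0
    env_steps: int = 0
    rollout_time: float = 0.0   # seconds spent generating rollouts
    training_time: float = 0.0  # seconds spent on gradient updates

    def record(self, env_steps: int, rollout_time: float, training_time: float) -> None:
        """Add the measurements from one training iteration."""
        self.iters += 1
        self.env_steps += env_steps
        self.rollout_time += rollout_time
        self.training_time += training_time

    def get_perf_stats(self) -> Dict[str, float]:
        """Return per-iteration means (in ms) and overall environment-step throughput."""
        n = max(self.iters, 1)  # avoid division by zero before the first iteration
        total = self.rollout_time + self.training_time
        return {
            "Mean rollout time per iter (ms)": 1000 * self.rollout_time / n,
            "Mean training time per iter (ms)": 1000 * self.training_time / n,
            "Env steps per second": self.env_steps / total if total > 0 else 0.0,
        }


stats = SimplePerfStats()
for _ in range(2):
    stats.record(env_steps=1000, rollout_time=0.5, training_time=0.5)
print(stats.get_perf_stats())
```

Accumulating totals and dividing only when the stats are requested keeps the per-iteration bookkeeping cheap, which matters when rollouts themselves run in milliseconds on the GPU.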
- class warp_drive.training.trainer.Trainer(env_wrapper=None, config=None, policy_tag_to_agent_id_map=None, create_separate_placeholders_for_each_policy=False, obs_dim_corresponding_to_num_agents='first', num_devices=1, device_id=0, results_dir=None, verbose=True)
Bases: object
The trainer object. Its main methods are train(), save_model_checkpoint(), load_model_checkpoint() and fetch_episode_states().
- fetch_episode_states(list_of_states=None, env_id=0)
Step through the environment and fetch the requested states (data arrays on the GPU) for an entire episode, evaluating the trained models.
- graceful_close()
- load_model_checkpoint(ckpts_dict=None)
Load the model parameters if a checkpoint path is specified.
- save_model_checkpoint(iteration=0)
Save the model parameters.
- train()
Perform training.
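`save_model_checkpoint(iteration)` and `load_model_checkpoint(ckpts_dict)` suggest per-iteration saves and a dict of checkpoint paths. The sketch below assumes `ckpts_dict` maps each policy tag to a file path and that files are named `<policy>_<iteration>`; both the layout and the helper names `save_checkpoints` and `load_checkpoints` are hypothetical. It uses `pickle` to stay dependency-free; a PyTorch trainer would normally call `torch.save(model.state_dict(), path)` and `model.load_state_dict(torch.load(path))` instead.

```python
import os
import pickle
import tempfile
from typing import Any, Dict, Optional


def save_checkpoints(params_by_policy: Dict[str, Any], save_dir: str, iteration: int = 0) -> Dict[str, str]:
    """Write one checkpoint file per policy; return a {policy_tag: path} dict (hypothetical layout)."""
    ckpts_dict = {}
    for policy, params in params_by_policy.items():
        path = os.path.join(save_dir, f"{policy}_{iteration}.pkl")
        with open(path, "wb") as f:
            pickle.dump(params, f)
        ckpts_dict[policy] = path
    return ckpts_dict


def load_checkpoints(ckpts_dict: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
    """Load parameters only for policies whose checkpoint path is specified."""
    loaded = {}
    for policy, path in (ckpts_dict or {}).items():
        if path:  # an empty or None path means "no checkpoint for this policy"
            with open(path, "rb") as f:
                loaded[policy] = pickle.load(f)
    return loaded


params = {"tagger": {"w": [0.1, 0.2]}, "runner": {"w": [0.3]}}
with tempfile.TemporaryDirectory() as tmp:
    ckpts = save_checkpoints(params, tmp, iteration=10)
    restored = load_checkpoints(ckpts)
```

Returning the paths from the save step gives exactly the dict that the load step consumes, so a run can be resumed by passing the saved mapping back in.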
- warp_drive.training.trainer.all_equal(iterable)
Check whether all elements of an iterable (e.g., a list) are identical.
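A minimal sketch of the check described above, written as the hypothetical `all_identical` rather than the library's own code. It compares every element against the first one, works on any iterable (including one-shot generators, since each element is read once), and treats an empty iterable as vacuously equal; whether `all_equal` handles the empty case the same way is an assumption.

```python
from typing import Iterable


def all_identical(iterable: Iterable) -> bool:
    """Return True if every element equals the first one (True for an empty iterable)."""
    iterator = iter(iterable)
    try:
        first = next(iterator)
    except StopIteration:
        return True  # no elements, so none can differ
    # all() stops at the first element that differs from `first`.
    return all(item == first for item in iterator)


print(all_identical([3, 3, 3]), all_identical([3, 4]))  # True False
```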
- warp_drive.training.trainer.recursive_merge_config_dicts(config, default_config)
Merge the configuration dictionary with the default configuration dictionary to fill in any missing configuration keys.
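A minimal sketch of the merge described above, written as the hypothetical `merge_with_defaults` rather than the library's own code. It assumes nested dicts are merged recursively, values already present in `config` take precedence, and both inputs are left unmodified (the library function may instead update `config` in place). The config keys in the usage are illustrative only.

```python
import copy
from typing import Any, Dict


def merge_with_defaults(config: Dict[str, Any], default_config: Dict[str, Any]) -> Dict[str, Any]:
    """Fill keys missing from `config` with values from `default_config`, recursing into nested dicts."""
    merged = copy.deepcopy(config)  # never mutate the caller's config
    for key, default_value in default_config.items():
        if key not in merged:
            # Missing key: copy the default so later edits cannot alias it.
            merged[key] = copy.deepcopy(default_value)
        elif isinstance(merged[key], dict) and isinstance(default_value, dict):
            # Both sides are sections: fill in missing keys one level down.
            merged[key] = merge_with_defaults(merged[key], default_value)
        # Otherwise the user-supplied value wins.
    return merged


config = {"trainer": {"num_envs": 100}}
defaults = {"trainer": {"num_envs": 1, "train_batch_size": 200}, "policy": {"lr": 0.001}}
print(merge_with_defaults(config, defaults))
```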
- warp_drive.training.trainer.verbose_print(message, device_id=None)