warp_drive.training package

Submodules

warp_drive.training.example_training_script_numba module

Example training script for the grid world and continuous versions of Tag.

warp_drive.training.example_training_script_numba.setup_trainer_and_train(run_configuration, device_id=0, num_devices=1, event_messenger=None, results_directory=None, verbose=True)

Create the environment wrapper, define the mapping from policies to agent IDs, create the trainer object, and run training.

warp_drive.training.example_training_script_pycuda module

Example training script for the grid world and continuous versions of Tag.

warp_drive.training.example_training_script_pycuda.setup_trainer_and_train(run_configuration, device_id=0, num_devices=1, event_messenger=None, results_directory=None, verbose=True)

Create the environment wrapper, define the mapping from policies to agent IDs, create the trainer object, and run training.
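Both example scripts define a mapping from policy tags to agent IDs, which the Trainer accepts as `policy_tag_to_agent_id_map`. The sketch below shows the general shape of such a mapping for a Tag-style game. It assumes, hypothetically, that the first `num_taggers` agent IDs are taggers and the rest are runners, and that the policy tags are named `"tagger"` and `"runner"`. The helper `build_policy_tag_to_agent_id_map` is illustrative and not part of WarpDrive.

```python
from typing import Dict, List


def build_policy_tag_to_agent_id_map(num_agents: int, num_taggers: int) -> Dict[str, List[int]]:
    """Hypothetical helper: assign the first num_taggers agent IDs to a
    "tagger" policy and the remaining IDs to a "runner" policy."""
    if not 0 < num_taggers < num_agents:
        raise ValueError("need at least one tagger and at least one runner")
    return {
        "tagger": list(range(num_taggers)),
        "runner": list(range(num_taggers, num_agents)),
    }


policy_map = build_policy_tag_to_agent_id_map(num_agents=5, num_taggers=2)

# Sanity check: every agent ID should belong to exactly one policy.
assigned = sorted(agent_id for ids in policy_map.values() for agent_id in ids)
assert assigned == list(range(5))
print(policy_map)  # {'tagger': [0, 1], 'runner': [2, 3, 4]}
```

Keeping the mapping as plain lists of integer IDs makes it easy to verify that the policies partition the agent set before handing it to the trainer.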

warp_drive.training.pytorch_lightning module

warp_drive.training.trainer module

The Trainer, PerfStats and Metrics classes

class warp_drive.training.trainer.Metrics

Bases: object

Metrics class to log and print the key metrics.

pretty_print(metrics)

class warp_drive.training.trainer.PerfStats

Bases: object

Performance stats that will be included in rollout metrics.

get_perf_stats()
static pretty_print(stats)

class warp_drive.training.trainer.Trainer(env_wrapper=None, config=None, policy_tag_to_agent_id_map=None, create_separate_placeholders_for_each_policy=False, obs_dim_corresponding_to_num_agents='first', num_devices=1, device_id=0, results_dir=None, verbose=True)

Bases: object

The trainer object. Provides the methods train(), save_model_checkpoint(), and fetch_episode_states().

fetch_episode_states(list_of_states=None, env_id=0)

Step through the environment and fetch the requested states (data arrays on the GPU) for an entire episode, using the trained models for evaluation.
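To make the idea of "fetching states for an entire episode" concrete, here is a CPU-only toy sketch. It is not WarpDrive's implementation: `ToyEnv`, its `loc_x` and `timestep` states, and `fetch_episode_states_sketch` are all hypothetical, and the `policy` callable stands in for the trained models. The sketch records the requested states after reset and after every step, so an episode of length T yields T + 1 snapshots per state.

```python
from typing import Callable, Dict, List


class ToyEnv:
    """Hypothetical 1-D environment: each agent's position moves by its action."""

    def __init__(self, num_agents: int, episode_length: int):
        self.num_agents = num_agents
        self.episode_length = episode_length
        self.reset()

    def reset(self) -> None:
        self.t = 0
        self.state = {"loc_x": [0.0] * self.num_agents, "timestep": [0]}

    def step(self, actions: List[float]) -> bool:
        self.state["loc_x"] = [x + a for x, a in zip(self.state["loc_x"], actions)]
        self.t += 1
        self.state["timestep"] = [self.t]
        return self.t >= self.episode_length  # True once the episode is done


def fetch_episode_states_sketch(
    env: ToyEnv,
    policy: Callable[[Dict[str, list]], List[float]],
    list_of_states: List[str],
) -> Dict[str, List[list]]:
    """Roll out one episode and record a copy of each requested state per step."""
    env.reset()
    episode = {name: [list(env.state[name])] for name in list_of_states}
    done = False
    while not done:
        actions = policy(env.state)  # stand-in for evaluating trained models
        done = env.step(actions)
        for name in list_of_states:
            episode[name].append(list(env.state[name]))
    return episode


env = ToyEnv(num_agents=2, episode_length=3)
states = fetch_episode_states_sketch(env, lambda s: [1.0, -1.0], ["loc_x", "timestep"])
print(states["loc_x"])  # [[0.0, 0.0], [1.0, -1.0], [2.0, -2.0], [3.0, -3.0]]
```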

graceful_close()
load_model_checkpoint(ckpts_dict=None)

Load the model parameters if a checkpoint path is specified.

save_model_checkpoint(iteration=0)

Save the model parameters.

train()

Perform training.

warp_drive.training.trainer.all_equal(iterable)

Check whether all elements of an iterable (e.g., a list) are identical.
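A minimal sketch of such a check, based on the standard `itertools.groupby` recipe; WarpDrive's own `all_equal` may be written differently. `groupby` collapses runs of consecutive equal elements, so the elements are all equal exactly when there is at most one group.

```python
from itertools import groupby
from typing import Iterable


def all_equal(iterable: Iterable) -> bool:
    """Return True if every element equals the others (True for an empty iterable)."""
    groups = groupby(iterable)
    # No first group means the iterable was empty; no second group means
    # every element fell into the same run of equal values.
    return next(groups, None) is None or next(groups, None) is None


print(all_equal([3, 3, 3]))  # True
print(all_equal([3, 3, 4]))  # False
```

Unlike `len(set(items)) <= 1`, this approach only needs `==` on the elements (no hashing), and it consumes generators lazily, stopping at the first mismatch.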

warp_drive.training.trainer.recursive_merge_config_dicts(config, default_config)

Merge the configuration dictionary with the default configuration dictionary to fill in any missing configuration keys.
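The merge can be sketched as a recursive walk over the defaults: a key missing from the user configuration is copied from the defaults, and when both sides hold a nested dictionary the walk recurses into it; user-provided values always take precedence. This is an illustrative sketch, not necessarily the library's implementation (for example, it returns a new dictionary instead of modifying `config` in place), and the configuration keys in the usage example are hypothetical.

```python
import copy
from typing import Any, Dict


def recursive_merge_config_dicts(
    config: Dict[str, Any], default_config: Dict[str, Any]
) -> Dict[str, Any]:
    """Return a copy of config with any missing keys filled in from default_config."""
    merged = copy.deepcopy(config)
    for key, default_value in default_config.items():
        if key not in merged:
            # Missing key: take the default (copied, so later edits don't alias it).
            merged[key] = copy.deepcopy(default_value)
        elif isinstance(merged[key], dict) and isinstance(default_value, dict):
            # Both sides are sections: merge them recursively.
            merged[key] = recursive_merge_config_dicts(merged[key], default_value)
        # Otherwise the user-provided value wins and is kept as-is.
    return merged


defaults = {"trainer": {"num_envs": 10, "train_batch_size": 100}, "saving": {"basedir": "/tmp"}}
user = {"trainer": {"num_envs": 400}}
print(recursive_merge_config_dicts(user, defaults))
# {'trainer': {'num_envs': 400, 'train_batch_size': 100}, 'saving': {'basedir': '/tmp'}}
```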

warp_drive.training.trainer.verbose_print(message, device_id=None)

Module contents