warp_drive.training.models package
Submodules
warp_drive.training.models.fully_connected module
The Fully Connected Network class
- class warp_drive.training.models.fully_connected.FullyConnected(env, fc_dims, policy, policy_tag_to_agent_id_map, create_separate_placeholders_for_each_policy=False, obs_dim_corresponding_to_num_agents='first')
Bases: torch.nn.modules.module.Module
Fully connected network implementation in PyTorch.
- forward(obs=None, batch_index=None, batch_size=None)
Forward pass through the model. Returns action probabilities and value functions.
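As a rough illustration of a forward pass that returns both action probabilities and value estimates, here is a minimal framework-agnostic sketch in NumPy. The architecture is an assumption for illustration only (shared ReLU layers sized by fc_dims, a softmax policy head and a scalar value head); the names init_layer and forward_sketch are hypothetical and are not part of WarpDrive.

```python
import numpy as np

rng = np.random.default_rng(0)


def init_layer(in_dim, out_dim):
    """Hypothetical helper: small random weights and a zero bias."""
    return rng.normal(0.0, 0.1, size=(in_dim, out_dim)), np.zeros(out_dim)


def forward_sketch(obs, hidden_layers, policy_head, value_head):
    """Assumed two-headed MLP: returns (action_probs, values)."""
    x = obs
    for weights, bias in hidden_layers:
        x = np.maximum(x @ weights + bias, 0.0)  # ReLU activation
    logits = x @ policy_head[0] + policy_head[1]
    # Numerically stable softmax over the action dimension.
    logits = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(logits)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    # One scalar value estimate per agent.
    values = (x @ value_head[0] + value_head[1])[..., 0]
    return probs, values


# Illustrative sizes (assumptions): 3 agents, 6 features, 4 actions.
obs_size, fc_dims, num_actions = 6, [16, 16], 4
dims = [obs_size] + fc_dims
hidden = [init_layer(dims[i], dims[i + 1]) for i in range(len(fc_dims))]
policy_head = init_layer(fc_dims[-1], num_actions)
value_head = init_layer(fc_dims[-1], 1)

obs = rng.normal(size=(3, obs_size))
probs, values = forward_sketch(obs, hidden, policy_head, value_head)
```

Here probs has one row of action probabilities per agent and values has one entry per agent.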
- get_flattened_obs()
If the observation space is of Box type, the observation is already flattened. If it is of Dict type, all the observation values are concatenated and flattened. Returns the concatenated and flattened observation.
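The flattening described above can be sketched as follows. This is a simplified stand-in, not the library code: a plain dict of NumPy arrays plays the role of a Dict observation, a plain array plays the role of a Box observation, and the function name flatten_obs_sketch is hypothetical. Values are concatenated in dict insertion order, which is also an assumption.

```python
import numpy as np


def flatten_obs_sketch(obs):
    """Return observations shaped (num_agents, total_feature_size)."""
    if isinstance(obs, dict):
        # Flatten each entry to (num_agents, -1), then join the features.
        parts = []
        for value in obs.values():
            value = np.asarray(value)
            parts.append(value.reshape(value.shape[0], -1))
        return np.concatenate(parts, axis=-1)
    # Box-like observations are assumed to be flat already.
    return obs


dict_obs = {
    "position": np.ones((3, 2)),   # 3 agents, 2 features each
    "grid": np.zeros((3, 2, 2)),   # 3 agents, a 2x2 feature map each
}
flat = flatten_obs_sketch(dict_obs)
flat_size = flat.shape[-1]  # per-agent size after flattening
```

In this sketch, flat_size corresponds to the quantity that get_flattened_obs_size() is documented to report.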
- get_flattened_obs_size()
Get the total size of the observations after flattening.
- name = 'torch_fully_connected'
- push_processed_obs_to_batch(batch_index, batch_size, processed_obs)
- reshape_and_flatten_obs(obs)
Note: WarpDrive assumes that all observations are shaped (num_agents, *feature_dim), i.e., the observation dimension corresponding to 'num_agents' is the first one. If the observation dimension corresponding to num_agents is last, the axes need to be permuted to align with WarpDrive's assumption.
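The axis permutation mentioned in the note can be sketched as follows for a single observation array without a leading environment or batch dimension. This is a simplification under stated assumptions: NumPy stands in for PyTorch tensors, the function name move_agents_axis_first is hypothetical, and 'last' is assumed to be the counterpart of the documented default 'first'.

```python
import numpy as np


def move_agents_axis_first(obs, obs_dim_corresponding_to_num_agents="first"):
    """Return obs with the num_agents axis in position 0."""
    if obs_dim_corresponding_to_num_agents == "first":
        return obs  # already (num_agents, *feature_dim)
    if obs_dim_corresponding_to_num_agents == "last":
        # (*feature_dim, num_agents) -> (num_agents, *feature_dim)
        return np.moveaxis(obs, -1, 0)
    raise ValueError("expected 'first' or 'last'")


# Illustrative shapes (assumptions): feature_dim = (4, 5), 3 agents last.
obs_agents_last = np.arange(4 * 5 * 3).reshape(4, 5, 3)
obs_agents_first = move_agents_axis_first(obs_agents_last, "last")
```

The feature axes keep their relative order, so element [i, j, k] of the input ends up at [k, i, j] in the output.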
- training: bool
- warp_drive.training.models.fully_connected.apply_logit_mask(logits, mask=None)
Mask values of 1 mark valid actions. Large negative values are added to the logits wherever the mask is 0, so that invalid actions receive near-zero probability after a softmax.
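The masking idea can be illustrated with a minimal NumPy sketch. It is not the library's code: the function name apply_logit_mask_sketch, the softmax helper and the constant LARGE_NEGATIVE are assumptions chosen for illustration (the documentation only says "huge negative values").

```python
import numpy as np

LARGE_NEGATIVE = -1e7  # assumed magnitude; not taken from the source


def apply_logit_mask_sketch(logits, mask=None):
    """Push logits of invalid actions (mask == 0) far negative."""
    if mask is None:
        return logits
    return logits + LARGE_NEGATIVE * (1 - mask)


def softmax(x):
    """Numerically stable softmax over the last axis."""
    shifted = x - x.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


logits = np.array([[1.0, 2.0, 3.0]])
mask = np.array([[1, 0, 1]])  # the middle action is invalid
probs = softmax(apply_logit_mask_sketch(logits, mask))
```

Adding a large negative offset rather than overwriting the logits keeps the operation differentiable and leaves the relative ordering of the valid actions unchanged.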