warp_drive.training.models package

Submodules

warp_drive.training.models.fully_connected module

The Fully Connected Network class

class warp_drive.training.models.fully_connected.FullyConnected(env, fc_dims, policy, policy_tag_to_agent_id_map, create_separate_placeholders_for_each_policy=False, obs_dim_corresponding_to_num_agents='first')

Bases: torch.nn.modules.module.Module

Fully connected network implementation in PyTorch.

forward(obs=None, batch_index=None, batch_size=None)

Forward pass through the model. Returns the action probabilities and the value function estimates.
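For intuition, here is a minimal sketch of the kind of computation such a forward pass performs. It assumes ReLU hidden layers sized by a list like fc_dims, plus two heads: a softmax policy head and a scalar value head. The helper names (init_params, forward) and the layer layout are hypothetical and are not WarpDrive's actual implementation. NumPy is used so the sketch runs without PyTorch.

```python
import numpy as np


def init_params(obs_size, fc_dims, num_actions, seed=0):
    """Randomly initialise weights for a hypothetical policy/value MLP."""
    rng = np.random.default_rng(seed)
    dims = [obs_size, *fc_dims]
    hidden = [
        (rng.normal(0.0, 0.1, (d_in, d_out)), np.zeros(d_out))
        for d_in, d_out in zip(dims[:-1], dims[1:])
    ]
    policy_head = (rng.normal(0.0, 0.1, (dims[-1], num_actions)), np.zeros(num_actions))
    value_head = (rng.normal(0.0, 0.1, (dims[-1], 1)), np.zeros(1))
    return hidden, policy_head, value_head


def softmax(logits):
    # Subtract the row max for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def forward(params, flat_obs):
    """Return (action_probs, values) for a batch of flattened observations."""
    hidden, (w_pi, b_pi), (w_v, b_v) = params
    x = flat_obs
    for w, b in hidden:
        x = np.maximum(x @ w + b, 0.0)  # ReLU hidden layer
    action_probs = softmax(x @ w_pi + b_pi)  # policy head
    values = (x @ w_v + b_v).squeeze(-1)     # value head, one scalar per agent
    return action_probs, values


params = init_params(obs_size=6, fc_dims=[32, 32], num_actions=4)
obs = np.ones((5, 6))  # 5 agents, 6 flattened features each
probs, values = forward(params, obs)
print(probs.shape, values.shape)  # (5, 4) (5,)
```

Sharing the hidden trunk between the two heads is a common actor-critic design choice. It is shown here only as an illustration.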

get_flattened_obs()

If the observation is of Box type, it is already flattened. If the observation is of Dict type, all of its values are flattened and concatenated. Returns the concatenated, flattened observations.
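The Dict case can be sketched as follows, assuming each value is an array shaped (num_agents, *feature_dim). A plain array stands in for a Box observation, and a Python dict stands in for a Dict observation. The function below is a hypothetical NumPy stand-in, not WarpDrive's method. Keys are concatenated in the dict's iteration order, which is an assumption; the real key order may differ.

```python
import numpy as np


def get_flattened_obs(obs):
    """Flatten observations shaped (num_agents, *feature_dim) to (num_agents, -1).

    A plain array stands in for a Box observation (assumed already flat);
    a dict stands in for a Dict observation, whose values are flattened
    per agent and concatenated in the dict's iteration order (assumption).
    """
    if isinstance(obs, dict):
        num_agents = next(iter(obs.values())).shape[0]
        parts = [np.asarray(v).reshape(num_agents, -1) for v in obs.values()]
        return np.concatenate(parts, axis=-1)
    return np.asarray(obs)


obs = {
    "position": np.zeros((3, 2)),  # 3 agents, 2 features each
    "grid": np.ones((3, 4, 4)),    # 3 agents, 4x4 local view each
}
flat = get_flattened_obs(obs)
print(flat.shape)  # (3, 18)
```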

get_flattened_obs_size()

Returns the total size of the observations after flattening.
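As a rough illustration, the flattened size can be computed from observation shapes alone. This sketch assumes the size counts features per agent, meaning the first (num_agents) axis is excluded, since that is the natural input width of a first fully connected layer. Whether WarpDrive counts it this way is an assumption. The shape-based signature is hypothetical.

```python
from math import prod


def get_flattened_obs_size(obs_shapes):
    """Per-agent feature count after flattening (hypothetical helper).

    obs_shapes is either a single shape tuple (num_agents, *feature_dim)
    or a dict mapping keys to such shapes, standing in for a Dict space.
    """
    if isinstance(obs_shapes, dict):
        return sum(prod(shape[1:]) for shape in obs_shapes.values())
    return prod(obs_shapes[1:])


print(get_flattened_obs_size((3, 6)))                                   # 6
print(get_flattened_obs_size({"position": (3, 2), "grid": (3, 4, 4)}))  # 18
```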

name = 'torch_fully_connected'
push_processed_obs_to_batch(batch_index, batch_size, processed_obs)
reshape_and_flatten_obs(obs)

Note: WarpDrive assumes that all observations are shaped (num_agents, *feature_dim), i.e., the observation dimension corresponding to ‘num_agents’ is the first one. If the observation dimension corresponding to num_agents is the last one, the axes must be permuted to align with WarpDrive’s assumption.
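The axis permutation described in the note can be sketched with np.moveaxis, which moves the last axis to the front and keeps the remaining axes in order. The parameter name mirrors the constructor argument obs_dim_corresponding_to_num_agents. The assumption that its accepted values are "first" and "last" comes from the default shown in the class signature and the wording of the note. to_agents_first is a hypothetical helper.

```python
import numpy as np


def to_agents_first(obs, obs_dim_corresponding_to_num_agents="first"):
    """Return obs with the num_agents axis moved to position 0.

    With "last", an array shaped (*feature_dim, num_agents) becomes
    (num_agents, *feature_dim); with "first" it is returned unchanged.
    """
    if obs_dim_corresponding_to_num_agents == "first":
        return obs
    if obs_dim_corresponding_to_num_agents == "last":
        return np.moveaxis(obs, -1, 0)
    raise ValueError(f"unexpected value: {obs_dim_corresponding_to_num_agents!r}")


obs_last = np.zeros((4, 4, 3))  # 4x4 features, 3 agents on the last axis
print(to_agents_first(obs_last, "last").shape)  # (3, 4, 4)
```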

training: bool
warp_drive.training.models.fully_connected.apply_logit_mask(logits, mask=None)

A mask value of 1 marks a valid action. A huge negative value is added to every logit whose mask value is 0, so those actions receive (near-)zero probability after the softmax.
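A minimal NumPy sketch of this masking technique follows. The constant LARGE_NEGATIVE is an assumption, because the exact value WarpDrive adds is not stated here. Any sufficiently large negative number drives the masked probabilities to zero after the softmax.

```python
import numpy as np

# Assumed constant: the exact value WarpDrive adds is not documented here.
LARGE_NEGATIVE = -1e20


def apply_logit_mask(logits, mask=None):
    """Push logits of invalid actions (mask == 0) towards -inf."""
    if mask is None:
        return logits
    return logits + (1.0 - mask) * LARGE_NEGATIVE


def softmax(x):
    shifted = x - x.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


logits = np.array([[1.0, 2.0, 3.0]])
mask = np.array([[1.0, 1.0, 0.0]])  # third action is invalid
probs = softmax(apply_logit_mask(logits, mask))
print(probs.round(3))  # approx [[0.269, 0.731, 0.0]]
```

Adding a large negative number avoids producing -inf directly, which can yield NaN values when every action in a row is masked.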

Module contents