splendor.agents.our_agents.ppo.ppo_rnn package
Subpackages
- splendor.agents.our_agents.ppo.ppo_rnn.gru package
- splendor.agents.our_agents.ppo.ppo_rnn.lstm package
Submodules
splendor.agents.our_agents.ppo.ppo_rnn.recurrent_ppo module
Base class for all PPO models that incorporate a recurrent unit in their neural network architecture.
- class splendor.agents.our_agents.ppo.ppo_rnn.recurrent_ppo.RecurrentPPO(input_dim: int, output_dim: int, recurrent_unit: GRU | LSTM | RNN)[source]
Bases: PPOBase
Base class for all PPO models with recurrent unit.
- abstract forward(x: Float[Tensor, 'batch sequence features'] | Float[Tensor, 'batch features'] | Float[Tensor, 'features'], action_mask: Float[Tensor, 'batch actions'] | Float[Tensor, 'actions'], hidden_state: Any, *args, **kwargs) tuple[Float[Tensor, 'batch actions'], Float[Tensor, 'batch 1'], Float[Tensor, 'batch hidden_dim']] [source]
Pass the input through the network to obtain predictions.
- Parameters:
x – the input to the network. Expected shape: one of (features,), (batch_size, features) or (batch_size, sequence_length, features).
action_mask – a binary masking tensor in which 1 signals a valid action and 0 signals an invalid action. Expected shape: (actions,) or (batch_size, actions), where actions equals len(ALL_ACTIONS), defined in splendor.Splendor.gym.envs.actions.
- Returns:
The action probabilities, the value estimate and the next hidden state.
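The shape and masking rules documented for forward can be illustrated with a small, framework-free sketch. It assumes that a missing batch dimension means a batch of one and a missing sequence dimension means a single time step, and it uses a masked softmax to turn raw scores into action probabilities. Both helpers (normalize_input_shape, masked_softmax) are hypothetical names that are not part of this package, and the real forward operates on torch tensors rather than Python lists.

```python
from __future__ import annotations

import math
from typing import Sequence


def normalize_input_shape(shape: tuple[int, ...]) -> tuple[int, int, int]:
    """Hypothetical helper: map an accepted input shape to (batch_size, sequence_length, features).

    Assumption: a missing batch dimension means batch_size=1 and a missing
    sequence dimension means sequence_length=1 (a single time step).
    """
    if len(shape) == 1:  # (features,)
        return (1, 1, shape[0])
    if len(shape) == 2:  # (batch_size, features)
        return (shape[0], 1, shape[1])
    if len(shape) == 3:  # (batch_size, sequence_length, features)
        return (shape[0], shape[1], shape[2])
    raise ValueError(f"unsupported input shape: {shape}")


def masked_softmax(logits: Sequence[float], action_mask: Sequence[float]) -> list[float]:
    """Hypothetical helper: softmax over valid actions only; masked-out actions get probability 0."""
    if len(logits) != len(action_mask):
        raise ValueError("logits and action_mask must have the same length")
    valid = [score for score, m in zip(logits, action_mask) if m > 0]
    if not valid:
        raise ValueError("action_mask allows no action")
    top = max(valid)  # subtract the max valid score for numerical stability
    exps = [math.exp(score - top) if m > 0 else 0.0 for score, m in zip(logits, action_mask)]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    print(normalize_input_shape((4, 10)))
    print(masked_softmax([2.0, 1.0, 0.5], [1, 0, 1]))
```

In the usage at the bottom, the second action is masked out, so its probability is exactly 0 and the remaining probability mass is split between the first and third actions.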
Return the initial hidden state to be used.
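The sentence above documents a method whose signature is not shown here. As a hedged illustration of what an initial hidden state usually looks like: PyTorch's torch.nn.GRU and torch.nn.RNN accept an initial hidden state h_0 of shape (num_layers, batch, hidden_size) for a unidirectional unit, while torch.nn.LSTM expects a pair (h_0, c_0) of that shape; zeros are the usual starting value. The sketch below builds such zero states with plain nested lists. initial_hidden_state is a hypothetical name, and the layout this package actually uses may differ (the forward return annotation above lists the hidden state as (batch, hidden_dim)).

```python
from __future__ import annotations

from typing import Literal

# Hypothetical alias: nested lists laid out as (num_layers, batch_size, hidden_dim).
Nested3D = list[list[list[float]]]


def initial_hidden_state(
    unit: Literal["gru", "lstm", "rnn"],
    num_layers: int,
    batch_size: int,
    hidden_dim: int,
) -> Nested3D | tuple[Nested3D, Nested3D]:
    """Hypothetical helper: zero initial state for a recurrent unit."""

    def zeros() -> Nested3D:
        # Build fresh lists on every call so h and c never share storage.
        return [[[0.0] * hidden_dim for _ in range(batch_size)] for _ in range(num_layers)]

    if unit == "lstm":
        return (zeros(), zeros())  # LSTM carries both a hidden state h and a cell state c
    if unit in ("gru", "rnn"):
        return zeros()  # GRU and plain RNN carry only a hidden state h
    raise ValueError(f"unknown recurrent unit: {unit}")
```

The state returned here would be passed to the first call of forward, and the next hidden state that forward returns would be fed into the following call, so the recurrent unit carries memory across time steps.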