splendor.agents.our_agents.ppo.ppo_rnn package

Subpackages

Submodules

splendor.agents.our_agents.ppo.ppo_rnn.recurrent_ppo module

Base class for all PPO models that incorporate a recurrent unit in their neural network architecture.

class splendor.agents.our_agents.ppo.ppo_rnn.recurrent_ppo.RecurrentPPO(input_dim: int, output_dim: int, recurrent_unit: GRU | LSTM | RNN)

Bases: PPOBase

Base class for all PPO models with a recurrent unit.
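RecurrentPPO leaves forward and init_hidden_state abstract, so a concrete model must supply both. The following is a minimal, self-contained sketch of what such a model could look like, assuming PyTorch and a single-layer GRU. It is not the project's implementation: the class name SketchRecurrentPPO, the hidden_dim parameter, the actor/critic heads and the choice to mask logits with -inf are all hypothetical, and it subclasses nn.Module directly because PPOBase is not shown here.

```python
import torch
from torch import nn


class SketchRecurrentPPO(nn.Module):
    """Hypothetical stand-in for a concrete RecurrentPPO subclass (not the project's code)."""

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = 32) -> None:
        super().__init__()
        # Assumed recurrent unit: one GRU layer, inputs laid out as (batch, sequence, features).
        self.recurrent_unit = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.actor = nn.Linear(hidden_dim, output_dim)  # action logits
        self.critic = nn.Linear(hidden_dim, 1)  # state-value estimate

    def init_hidden_state(self, device: torch.device, batch_size: int = 1) -> torch.Tensor:
        # nn.GRU expects h0 of shape (num_layers, batch, hidden_dim), even with batch_first=True.
        return torch.zeros(1, batch_size, self.recurrent_unit.hidden_size, device=device)

    def forward(
        self, x: torch.Tensor, action_mask: torch.Tensor, hidden_state: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Normalize x to (batch, sequence, features).
        if x.dim() == 1:
            x = x.view(1, 1, -1)
        elif x.dim() == 2:
            x = x.unsqueeze(1)
        if action_mask.dim() == 1:
            action_mask = action_mask.unsqueeze(0)
        # Accept a (batch, hidden_dim) hidden state, as returned below, by restoring the layer axis.
        if hidden_state.dim() == 2:
            hidden_state = hidden_state.unsqueeze(0)
        output, next_hidden = self.recurrent_unit(x, hidden_state)
        last = output[:, -1, :]  # features of the final time step: (batch, hidden_dim)
        # Invalid actions (mask == 0) get -inf logits, hence exactly zero probability.
        logits = self.actor(last).masked_fill(action_mask == 0, float("-inf"))
        probs = torch.softmax(logits, dim=-1)
        value = self.critic(last)
        return probs, value, next_hidden.squeeze(0)


model = SketchRecurrentPPO(input_dim=8, output_dim=4)
h = model.init_hidden_state(torch.device("cpu"), batch_size=2)
mask = torch.tensor([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 0.0]])
probs, value, h = model(torch.randn(2, 5, 8), mask, h)
print(probs.shape, value.shape, h.shape)  # torch.Size([2, 4]) torch.Size([2, 1]) torch.Size([2, 32])
```

Returning the hidden state as (batch, hidden_dim) matches the return annotation of forward; the sketch re-adds the layer axis on the next call so the state can be fed back in step by step.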

abstract forward(x: Float[Tensor, 'batch sequence features'] | Float[Tensor, 'batch features'] | Float[Tensor, 'features'], action_mask: Float[Tensor, 'batch actions'] | Float[Tensor, 'actions'], hidden_state: Any, *args, **kwargs) tuple[Float[Tensor, 'batch actions'], Float[Tensor, 'batch 1'], Float[Tensor, 'batch hidden_dim']]

Pass the input through the network to obtain predictions.

Parameters:
  • x – the input to the network. Expected shape: (features,), (batch_size, features) or (batch_size, sequence_length, features).

  • action_mask – a binary masking tensor, where 1 signals a valid action and 0 signals an invalid action. Expected shape: (actions,) or (batch_size, actions), where actions equals len(ALL_ACTIONS), defined in splendor.Splendor.gym.envs.actions.

  • hidden_state – the current hidden state of the recurrent unit.

Returns:

The action probabilities, the value estimate and the next hidden state.
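The source does not say how action_mask is turned into action probabilities. A common approach, shown here as an assumption rather than the project's code, is a masked softmax: invalid actions are excluded before normalizing, so they receive probability 0 and the valid actions still sum to 1. The function name masked_softmax is hypothetical, and plain Python lists stand in for tensors.

```python
import math


def masked_softmax(logits: list[float], action_mask: list[int]) -> list[float]:
    """Hypothetical helper: softmax over logits, giving masked-out (0) actions probability 0."""
    if len(logits) != len(action_mask):
        raise ValueError("logits and action_mask must have the same length")
    valid = [z for z, m in zip(logits, action_mask) if m]
    if not valid:
        raise ValueError("action_mask must contain at least one valid action")
    # Subtract the largest valid logit for numerical stability.
    shift = max(valid)
    exps = [math.exp(z - shift) if m else 0.0 for z, m in zip(logits, action_mask)]
    total = sum(exps)
    return [e / total for e in exps]


# The large logit of the masked action (index 1) is ignored.
print(masked_softmax([2.0, 5.0, 2.0, 0.0], [1, 0, 1, 0]))  # [0.5, 0.0, 0.5, 0.0]
```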

abstract init_hidden_state(device: device) tuple[Any, Any]

Return the initial hidden state of the recurrent unit, created on the given device.
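The annotation tuple[Any, Any] leaves the exact hidden-state layout to each subclass. For reference, PyTorch's recurrent units differ here: nn.LSTM takes a pair (h0, c0), while nn.GRU and nn.RNN take a single h0, and in all cases h0 has shape (num_layers * num_directions, batch, hidden_size) regardless of batch_first. The helper below, zero_hidden_state, is a hypothetical sketch of building a zero initial state for any of the three units; it is not part of the package.

```python
import torch
from torch import nn


def zero_hidden_state(unit: nn.RNNBase, batch_size: int, device: torch.device):
    """Hypothetical helper: zero initial state for an nn.GRU, nn.LSTM or nn.RNN."""
    num_directions = 2 if unit.bidirectional else 1
    layers = unit.num_layers * num_directions
    # An LSTM with proj_size > 0 projects h to proj_size; c keeps hidden_size.
    proj_size = getattr(unit, "proj_size", 0)
    h_size = proj_size if proj_size > 0 else unit.hidden_size
    h0 = torch.zeros(layers, batch_size, h_size, device=device)
    if isinstance(unit, nn.LSTM):
        # An LSTM carries a cell state alongside the hidden state: (h0, c0).
        c0 = torch.zeros(layers, batch_size, unit.hidden_size, device=device)
        return h0, c0
    return h0


device = torch.device("cpu")
lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, batch_first=True)
out, (hn, cn) = lstm(torch.randn(3, 5, 8), zero_hidden_state(lstm, 3, device))
gru = nn.GRU(input_size=8, hidden_size=16, batch_first=True)
out, hn = gru(torch.randn(3, 5, 8), zero_hidden_state(gru, 3, device))
```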

Module contents