"""Source code for splendor.agents.our_agents.ppo.ppo_rnn.recurrent_ppo."""

from abc import abstractmethod
from typing import Any, Union, Tuple, Callable

import torch
import torch.nn as nn

from jaxtyping import Float

from splendor.agents.our_agents.ppo.ppo_base import PPOBase


class RecurrentPPO(PPOBase):
    """Base class for PPO networks built around a recurrent unit (GRU, LSTM or vanilla RNN)."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        recurrent_unit: Union[nn.GRU, nn.LSTM, nn.RNN],
    ):
        super().__init__(input_dim, output_dim)
        self.recurrent_unit = recurrent_unit

    @abstractmethod
    def forward(
        self,
        x: Union[
            Float[torch.Tensor, "batch sequence features"],
            Float[torch.Tensor, "batch features"],
            Float[torch.Tensor, "features"],
        ],
        action_mask: Union[
            Float[torch.Tensor, "batch actions"],
            Float[torch.Tensor, "actions"],
        ],
        *args,
        **kwargs,
    ) -> Tuple[
        Float[torch.Tensor, "batch actions"],
        Float[torch.Tensor, "batch 1"],
        Float[torch.Tensor, "batch hidden_dim"],
    ]:
        """
        Pass the input through the network to obtain predictions.

        :param x: the input to the network. Expected shape: one of
            (features,), (batch_size, features) or
            (batch_size, sequence_length, features).
        :param action_mask: a binary masking tensor, where 1 marks a valid
            action and 0 marks an invalid action. Expected shape: (actions,)
            or (batch_size, actions), where actions equals len(ALL_ACTIONS)
            from splendor.Splendor.gym.envs.actions.
        :return: the action probabilities, the value estimate and the next
            hidden state.
        """
        raise NotImplementedError()

    @abstractmethod
    def init_hidden_state(self) -> Any:
        """
        Return the initial hidden state to be used.
        """
        raise NotImplementedError()
# A factory of PPO networks: a callable taking input_dim and output_dim (plus
# any extra arguments) and returning a PPOBase. Callable cannot mix explicit
# parameter types with an ellipsis, so the signature is left open here.
PPOBaseFactory = Callable[..., PPOBase]