Source code for splendor.agents.our_agents.ppo.ppo_base

"""
Base class for all neural networks that should be used by a PPO agent.
"""

from abc import ABC, abstractmethod
from typing import Any, Protocol

import torch
from jaxtyping import Float
from torch import nn


def unused(x: Any) -> None:  # noqa: ANN401
    """
    Explicitly discard an argument, similar to a ``(void)`` cast in C.

    :param x: any value; it is neither inspected nor modified.
    """
    # Dropping the local name counts as a "use" without touching the object.
    del x
class PPOBase(nn.Module, ABC):
    """
    Abstract parent of every neural network that a PPO agent may use.

    Records the input/output dimensions and offers a helper for building
    a stack of fully-connected hidden layers.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
    ) -> None:
        """
        Initialise the underlying ``nn.Module`` and remember the dimensions.

        :param input_dim: stored as ``self.input_dim``.
        :param output_dim: stored as ``self.output_dim``.
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

    @abstractmethod
    def forward(
        self,
        x: (
            Float[torch.Tensor, "batch sequence features"]
            | Float[torch.Tensor, "batch features"]
            | Float[torch.Tensor, " features"]
        ),
        action_mask: (
            Float[torch.Tensor, "batch actions"] | Float[torch.Tensor, " actions"]
        ),
        *args,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor, Any]:
        """
        Run the input through the network to obtain predictions.

        :param x: the network input. Expected shape is one of
            ``(features,)``, ``(batch_size, features)`` or
            ``(batch_size, sequence_length, features)``.
        :param action_mask: a binary mask where 1 marks a valid action and
            0 marks an invalid one. Expected shape is ``(actions,)`` or
            ``(batch_size, actions)``, where ``actions`` equals
            ``len(ALL_ACTIONS)`` from ``splendor.Splendor.gym.envs.actions``.
        :return: a 3-tuple whose first two items are the action
            probabilities and the value estimate. The third item is typed
            ``Any`` and is left to the concrete subclass
            (NOTE(review): presumably the next hidden state -- confirm).
        :raises NotImplementedError: always, in this abstract definition.
        """
        raise NotImplementedError()

    def init_hidden_state(self, device: torch.device) -> Any:  # noqa: ANN401
        """
        Return the initial hidden state to be used.

        This base implementation ignores ``device`` and returns ``None``.

        :param device: unused here.
        :return: ``None``.
        """
        # Both names are deliberately ignored; mark them so linters stay quiet.
        unused(self)
        unused(device)
        return None

    @staticmethod
    def create_hidden_layers(
        input_dim: int, hidden_layers_dims: list[int], dropout: float
    ) -> nn.Module:
        """
        Build the hidden-layer stack described by the given widths.

        Every entry of ``hidden_layers_dims`` contributes, in this order,
        ``Linear -> LayerNorm -> Dropout -> ReLU``.

        :param input_dim: number of input features of the first ``Linear``.
        :param hidden_layers_dims: output width of each successive block;
            an empty list yields an empty ``nn.Sequential``.
        :param dropout: probability handed to every ``nn.Dropout``.
        :return: an ``nn.Sequential`` holding all created modules.
        """
        # Pair each width with its predecessor: (in, h1), (h1, h2), ...
        widths = [input_dim, *hidden_layers_dims]
        modules: list[nn.Module] = [
            module
            for fan_in, fan_out in zip(widths, widths[1:])
            for module in (
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.Dropout(dropout),
                nn.ReLU(),
            )
        ]
        return nn.Sequential(*modules)
class PPOBaseFactory(Protocol):
    """
    Structural type for callables that construct PPO models.

    Any object callable as ``factory(input_dim, output_dim, *args, **kwargs)``
    and returning a ``PPOBase`` satisfies this protocol.
    """

    # pylint: disable=too-few-public-methods

    def __call__(self, input_dim: int, output_dim: int, *args, **kwargs) -> PPOBase:
        ...