Source code for splendor.agents.our_agents.ppo.ppo_agent_base

"""
Definition for a base class for all PPO-based agents.
"""

from abc import abstractmethod

import torch
from torch import nn

from splendor.splendor.splendor_model import SplendorGameRule, SplendorState
from splendor.splendor.types import ActionType
from splendor.template import Agent

from .ppo_base import PPOBase


[docs] class PPOAgentBase(Agent): """ base class for all PPO-based agents. """ def __init__(self, _id: int, load_net: bool = True) -> None: super().__init__(_id) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.net: nn.Module | None = None if load_net: self.load_policy(self.load())
[docs] @abstractmethod def SelectAction( self, actions: list[ActionType], game_state: SplendorState, game_rule: SplendorGameRule, ) -> ActionType: """ select an action to play from the given actions. """ raise NotImplementedError()
[docs] @abstractmethod def load(self) -> PPOBase: """ load and return the weights of the network. """ raise NotImplementedError()
[docs] def load_policy(self, policy: nn.Module) -> None: """ Use a given policy as the agent's network policy. """ self.net = policy.to(self.device) self.net.eval()