Source code for splendor.agents.our_agents.ppo.ppo_agent_base

from pathlib import Path
from abc import abstractmethod
from typing import List

import numpy as np
import torch
import gymnasium as gym

from splendor.template import Agent
from splendor.Splendor.splendor_model import SplendorState, SplendorGameRule
from splendor.Splendor.types import ActionType

from .ppo_base import PPOBase


class PPOAgentBase(Agent):
    """
    Common base class shared by every PPO-based agent.

    On construction the agent picks a torch device, obtains its network
    from :meth:`load`, moves that network onto the device and calls
    ``eval()`` on it. Subclasses supply the network via :meth:`load` and
    the move-selection policy via :meth:`SelectAction`.

    NOTE(review): ``@abstractmethod`` only prevents instantiation when the
    class uses an ``ABCMeta`` metaclass; whether ``Agent`` provides one is
    not visible from this file -- confirm. Either way, the default bodies
    below raise ``NotImplementedError`` when they are not overridden.
    """

    def __init__(self, _id):
        """
        Initialise the agent and prepare its network.

        :param _id: agent identifier, forwarded unchanged to ``Agent.__init__``.
        """
        super().__init__(_id)

        # Use CUDA when torch reports it as available, otherwise the CPU.
        if torch.cuda.is_available():
            device_name = "cuda"
        else:
            device_name = "cpu"
        self.device = torch.device(device_name)

        # load() is subclass-provided; keep whatever ``.to`` returns.
        network = self.load()
        self.net = network.to(self.device)
        # presumably switches a torch module to evaluation mode -- confirm
        # against PPOBase.
        self.net.eval()

    @abstractmethod
    def load(self) -> PPOBase:
        """
        load the weights of the network.

        :return: the network (a ``PPOBase``) this agent should use.
        :raises NotImplementedError: always, unless overridden.
        """
        raise NotImplementedError()

    @abstractmethod
    def SelectAction(
        self,
        actions: List[ActionType],
        game_state: SplendorState,
        game_rule: SplendorGameRule,
    ) -> ActionType:
        """
        select an action to play from the given actions.

        :param actions: the candidate actions to choose from.
        :param game_state: the ``SplendorState`` passed in by the caller.
        :param game_rule: the ``SplendorGameRule`` passed in by the caller.
        :return: the chosen action.
        :raises NotImplementedError: always, unless overridden.
        """
        raise NotImplementedError()