Source code for splendor.agents.our_agents.ppo.self_attn.ppo_agent

"""
An agent which uses PPO with self-attention.
"""

from pathlib import Path
from typing import override

import numpy as np
import torch
from numpy.typing import NDArray

from splendor.agents.our_agents.ppo.ppo_agent_base import PPOAgentBase
from splendor.agents.our_agents.ppo.ppo_base import PPOBase
from splendor.agents.our_agents.ppo.utils import load_saved_model
from splendor.splendor.features import extract_metrics_with_cards
from splendor.splendor.gym.envs.utils import (
    create_action_mapping,
    create_legal_actions_mask,
)
from splendor.splendor.splendor_model import SplendorGameRule, SplendorState
from splendor.splendor.types import ActionType

from .network import PPOSelfAttention

# Location of the saved weights file: "ppo_model.pth" in the directory that
# contains this module. PPOSelfAttentionAgent.load passes it to load_saved_model.
DEFAULT_SAVED_PPO_SELF_ATTENTION_PATH = Path(__file__).parent / "ppo_model.pth"


class PPOSelfAttentionAgent(PPOAgentBase):
    """
    Agent that picks its moves with a PPO network using self-attention.
    """

    @override
    def SelectAction(
        self,
        actions: list[ActionType],
        game_state: SplendorState,
        game_rule: SplendorGameRule,
    ) -> ActionType:
        """
        Choose one of the given actions by querying the network.

        The features extracted from ``game_state`` and the legal-actions mask
        are fed to ``self.net``; the index at which the first output of the
        network is maximal is looked up in the action mapping.

        :param actions: the actions to choose from.
        :param game_state: the current state of the game.
        :param game_rule: not used by this method; present because the
            overridden signature requires it.
        :return: the entry of the action mapping at the arg-max index.
        """
        # Gradients are not needed here, only a forward pass.
        with torch.no_grad():
            features: NDArray = extract_metrics_with_cards(game_state, self.id)
            features = features.astype(np.float32)

            # Convert to a double tensor with an extra leading dimension
            # (``unsqueeze(0)``) and move it to the agent's device.
            net_input: torch.Tensor = torch.from_numpy(features).double()
            net_input = net_input.unsqueeze(0).to(self.device)

            legal_mask_array = create_legal_actions_mask(actions, game_state, self.id)
            legal_mask = torch.from_numpy(legal_mask_array).double().to(self.device)

            # this assertion only narrows the type of ``self.net`` for mypy.
            assert self.net is not None
            # Only the first output of the network is used; the rest is ignored.
            scores, *_ = self.net(net_input, legal_mask)
            best_index = scores.argmax()
            action_by_index = create_action_mapping(actions, game_state, self.id)

        return action_by_index[best_index.item()]

    @override
    def load(self) -> PPOBase:
        """
        Load the saved network.

        :return: the result of ``load_saved_model`` for the default weights
            path and the ``PPOSelfAttention`` network class.
        """
        weights_path = DEFAULT_SAVED_PPO_SELF_ATTENTION_PATH
        return load_saved_model(weights_path, PPOSelfAttention)
# Module-level alias for the agent class defined above.
# NOTE(review): presumably whatever loads this module looks the agent up by
# the name `myAgent` (hence the non-PEP 8 name) - confirm against the caller.
myAgent = PPOSelfAttentionAgent # pylint: disable=invalid-name