import numpy as np
import torch
import gymnasium as gym
from pathlib import Path
from splendor.template import Agent
from .ppo import PPO, DROPOUT
from splendor.Splendor.features import extract_metrics_with_cards
from splendor.Splendor.gym.envs.utils import (
create_action_mapping,
create_legal_actions_mask,
)
class PPOAgent(Agent):
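    """Splendor agent that plays greedily from a pre-trained PPO policy network."""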
def __init__(self, _id):
super().__init__(_id)
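        # Throwaway env, used only to read the observation/action space sizes.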
env = gym.make("splendor-v1", agents=[])
        # Build the policy/value network and load the trained weights bundled with this module.
self.net = PPO(
env.observation_space.shape[0], env.action_space.n, dropout=DROPOUT
).double()
checkpoint = torch.load(
            str(Path(__file__).parent / "ppo_model.pth"),
weights_only=False,
map_location="cpu",
)
self.net.load_state_dict(checkpoint["model_state_dict"])
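        # Restore the input-normalization statistics saved with the checkpoint;
        # they may live on a dedicated input_norm layer or on the network itself.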
if hasattr(self.net, "input_norm"):
self.net.input_norm.running_mean = checkpoint["running_mean"]
self.net.input_norm.running_var = checkpoint["running_var"]
else:
self.net.running_mean = checkpoint["running_mean"]
self.net.running_var = checkpoint["running_var"]
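        # eval() disables dropout for deterministic inference.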
self.net.eval()
def SelectAction(self, actions, game_state, game_rule):
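        """Return the legal action whose masked policy score is highest."""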
        state: np.ndarray = extract_metrics_with_cards(game_state, self.id).astype(
            np.float32
        )
        state_tensor: torch.Tensor = torch.from_numpy(state).double()
        action_mask = torch.from_numpy(
            create_legal_actions_mask(actions, game_state, self.id)
        ).double()
        # Inference only, so skip autograd bookkeeping.
        with torch.no_grad():
            action_pred, value_pred = self.net(state_tensor, action_mask)
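        # Greedy decode: argmax over the masked predictions, then map the
        # flat action index back to a concrete game action.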
chosen_action = action_pred.argmax()
mapping = create_action_mapping(actions, game_state, self.id)
return mapping[chosen_action.item()]
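
# Alias the Splendor game runner imports when loading this agent.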
myAgent = PPOAgent