Source code for splendor.agents.our_agents.ppo.utils

from typing import Optional
from pathlib import Path
from functools import cache

import torch
import gymnasium as gym

from gymnasium.spaces.utils import flatdim

from .ppo_base import PPOBase, PPOBaseFactory
from .network import PPO, DROPOUT


# Weights bundled with the package: "ppo_model.pth" sits in the same
# directory as this module.
DEFAULT_SAVED_PPO_PATH = Path(__file__).parent.joinpath("ppo_model.pth")


@cache
def load_saved_model(
    path: Path,
    ppo_factory: PPOBaseFactory,
    *args,
    **kwargs,
) -> PPOBase:
    """
    Build a network with ``ppo_factory`` and load the checkpoint at ``path`` into it.

    The network's input and output sizes are the flattened sizes of the
    observation and action spaces of the ``"splendor-v1"`` environment. The
    network is converted to double precision before the weights are loaded.

    Because the function is wrapped in ``@cache``, repeated calls with equal
    arguments return the *same* network instance. Every argument, including
    anything passed through ``*args``/``**kwargs``, must be hashable.

    :param path: path of the checkpoint file to load.
    :param ppo_factory: callable that builds the network. It is called as
        ``ppo_factory(input_dim, output_dim, *args, **kwargs)``.
    :param args: extra positional arguments forwarded to ``ppo_factory``.
    :param kwargs: extra keyword arguments forwarded to ``ppo_factory``.
    :return: the network with its state dict and input-normalization
        statistics (``running_mean`` / ``running_var``) restored.
    :raises KeyError: if the checkpoint lacks ``"model_state_dict"``,
        ``"running_mean"`` or ``"running_var"``.
    """
    # The environment is only needed to derive the network's dimensions.
    # Close it immediately so the instance does not leak.
    env = gym.make("splendor-v1", agents=[])
    try:
        input_dim = flatdim(env.observation_space)
        output_dim = flatdim(env.action_space)
    finally:
        env.close()

    net = ppo_factory(input_dim, output_dim, *args, **kwargs).double()

    # NOTE(review): weights_only=False makes torch.load unpickle arbitrary
    # objects, which can execute code. Only load checkpoints from trusted
    # sources.
    checkpoint = torch.load(
        str(path),
        weights_only=False,
        map_location="cpu",
    )
    net.load_state_dict(checkpoint["model_state_dict"])

    if hasattr(net, "input_norm"):
        # both running_mean & running_var are stored as (1, flatdim(env.observation_space))
        # rather than (flatdim(env.observation_space),)
        net.input_norm.running_mean = checkpoint["running_mean"].squeeze(0)
        net.input_norm.running_var = checkpoint["running_var"].squeeze(0)
    else:
        net.running_mean = checkpoint["running_mean"]
        net.running_var = checkpoint["running_var"]

    return net
@cache
def load_saved_ppo(path: Optional[Path] = None) -> PPO:
    """
    Return a ``PPO`` network restored from the checkpoint at ``path``.

    When ``path`` is ``None``, the weights bundled with the package
    (``DEFAULT_SAVED_PPO_PATH``) are used instead. The network is built with
    ``dropout=DROPOUT`` and loaded through ``load_saved_model``.
    """
    checkpoint_path = DEFAULT_SAVED_PPO_PATH if path is None else path
    return load_saved_model(checkpoint_path, PPO, dropout=DROPOUT)