from typing import Optional
from pathlib import Path
from functools import cache
import torch
import gymnasium as gym
from gymnasium.spaces.utils import flatdim
from .ppo_base import PPOBase, PPOBaseFactory
from .network import PPO, DROPOUT
# Default location of the saved PPO checkpoint: a ``ppo_model.pth`` file
# located in the same directory as this module.
DEFAULT_SAVED_PPO_PATH = Path(__file__).with_name("ppo_model.pth")
def load_saved_model(
    path: Path,
    ppo_factory: PPOBaseFactory,
    *args,
    **kwargs,
) -> PPOBase:
    """
    Build a PPO network and attach the running statistics stored in a checkpoint.

    A temporary ``splendor-v1`` environment is created only to measure the
    flattened sizes of its observation and action spaces; these sizes, followed
    by ``*args`` and ``**kwargs``, are passed to ``ppo_factory``.

    :param path: location of the checkpoint file to read.
    :param ppo_factory: callable that builds the network from the flattened
        observation size and the flattened action size.
    :param args: extra positional arguments forwarded to ``ppo_factory``.
    :param kwargs: extra keyword arguments forwarded to ``ppo_factory``.
    :return: the network returned by ``ppo_factory`` with the checkpoint's
        ``running_mean`` / ``running_var`` attached.
    """
    env = gym.make("splendor-v1", agents=[])
    try:
        observation_dim = flatdim(env.observation_space)
        action_dim = flatdim(env.action_space)
    finally:
        # The environment is needed only for its space sizes; release it
        # even if measuring the spaces fails.
        env.close()

    net = ppo_factory(observation_dim, action_dim, *args, **kwargs)

    # Map every stored tensor to the CPU so a checkpoint written on a GPU
    # machine can still be read on a CPU-only one.
    checkpoint = torch.load(path, map_location="cpu")
    # NOTE(review): only the running statistics are read from the checkpoint
    # here; confirm whether the network parameters are expected to be
    # restored from it as well.

    if hasattr(net, "input_norm"):
        # both running_mean & running_var are stored as (1, flatdim(env.observation_space))
        # rather than (flatdim(env.observation_space),)
        net.input_norm.running_mean = checkpoint["running_mean"].squeeze(0)
        net.input_norm.running_var = checkpoint["running_var"].squeeze(0)
    else:
        # No normalization layer: keep the statistics on the network itself,
        # in the shape they were stored in.
        net.running_mean = checkpoint["running_mean"]
        net.running_var = checkpoint["running_var"]
    return net
def load_saved_ppo(path: Optional[Path] = None) -> PPO:
    """
    Load a saved PPO model, defaulting to the weights installed with the agent.

    :param path: checkpoint file to load; when ``None``,
        ``DEFAULT_SAVED_PPO_PATH`` is used instead.
    :return: the result of ``load_saved_model`` called with the ``PPO``
        factory and ``dropout=DROPOUT``.
    """
    if path is None:
        # No explicit checkpoint requested: fall back to the installed one.
        path = DEFAULT_SAVED_PPO_PATH
    return load_saved_model(path, PPO, dropout=DROPOUT)