Source code for splendor.agents.our_agents.ppo.common

from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as distributions


from .constants import (
    ENTROPY_COEFFICIENT,
    VALUE_COEFFICIENT,
    VERY_SMALL_EPSILON,
)


def calculate_returns(rewards, discount_factor, normalize=True):
    """
    Compute the discounted return of every step of an episode.

    The return of step ``t`` is
    ``rewards[t] + discount_factor * return[t + 1]``, accumulated backwards
    starting from the last reward.

    :param rewards: reversible sequence of per-step rewards, ordered from the
        first step to the last.
    :param discount_factor: multiplier applied to the return of the
        following step.
    :param normalize: when true, standardize the returns (subtract their mean
        and divide by their standard deviation).
    :return: tensor holding one return per reward, in the same order as
        ``rewards``.
    """
    discounted = []
    running_return = 0
    # Accumulate back-to-front with O(1) appends and restore the step order
    # once at the end, instead of shifting the whole list on every insert.
    for reward in reversed(rewards):
        running_return = reward + running_return * discount_factor
        discounted.append(running_return)
    discounted.reverse()

    returns = torch.tensor(discounted)

    if normalize:
        centered = returns - returns.mean()
        # torch's default std() is the sample-corrected estimator, which is
        # NaN for fewer than two elements; a lone return has no spread to
        # scale by, so it is only mean-centred (i.e. becomes 0).
        if returns.numel() > 1:
            # avoid possible division by 0
            centered = centered / (returns.std() + VERY_SMALL_EPSILON)
        returns = centered

    return returns
def calculate_advantages(returns, values, normalize=True):
    """
    Compute advantages as the difference between returns and value estimates.

    :param returns: tensor of returns.
    :param values: tensor of value estimates, subtracted element-wise from
        ``returns``.
    :param normalize: when true, standardize the advantages (subtract their
        mean and divide by their standard deviation).
    :return: tensor of (optionally standardized) advantages.
    """
    advantages = returns - values
    if not normalize:
        return advantages

    centered = advantages - advantages.mean()
    # torch's default std() is the sample-corrected estimator, which is NaN
    # for fewer than two elements; a lone advantage has no spread to scale
    # by, so it is only mean-centred (i.e. becomes 0).
    if advantages.numel() > 1:
        # avoid possible division by 0
        centered = centered / (advantages.std() + VERY_SMALL_EPSILON)
    return centered
def calculate_policy_loss(
    action_prob: torch.Tensor,
    actions: torch.Tensor,
    log_prob_actions: torch.Tensor,
    advantages: torch.Tensor,
    ppo_clip,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the clipped-surrogate policy loss of PPO.

    :param action_prob: action probabilities produced by the current policy,
        used to build a categorical distribution.
    :param actions: the actions that were taken when the data was collected.
    :param log_prob_actions: log-probabilities of ``actions`` recorded when
        the data was collected.
    :param advantages: advantage estimates weighting each action.
    :param ppo_clip: the probability ratio is clamped to
        ``[1 - ppo_clip, 1 + ppo_clip]``.
    :return: tuple of (policy loss, KL-divergence estimate detached and moved
        to the CPU, mean entropy of the current policy).
    """
    current_policy = distributions.Categorical(action_prob)

    # entropy bonus - use to improve exploration.
    # as seen here (bullet #10):
    # https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/
    mean_entropy = current_policy.entropy().mean()

    # score the previously taken actions under the current policy
    current_log_probs = current_policy.log_prob(actions)
    ratio = torch.exp(current_log_probs - log_prob_actions)

    unclipped_objective = ratio * advantages
    clipped_ratio = ratio.clamp(min=1.0 - ppo_clip, max=1.0 + ppo_clip)
    clipped_objective = clipped_ratio * advantages

    # pessimistic (element-wise smaller) objective, negated so it can be
    # minimized
    surrogate_loss = -torch.minimum(unclipped_objective, clipped_objective).mean()

    # diagnostic only: mean difference between old and new log-probabilities
    approx_kl = (log_prob_actions - current_log_probs).mean().detach().cpu()

    return surrogate_loss, approx_kl, mean_entropy
def calculate_loss(
    policy_loss: torch.Tensor, value_loss: torch.Tensor, entropy_bonus: torch.Tensor
) -> torch.Tensor:
    """
    Combine the individual terms into the final clipped-objective PPO loss.

    The value loss is scaled by ``VALUE_COEFFICIENT`` and added to the policy
    loss; the entropy bonus is scaled by ``ENTROPY_COEFFICIENT`` and
    subtracted. Same formulation as seen here:
    https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/model.py#L91

    :param policy_loss: policy loss term.
    :param value_loss: value loss term.
    :param entropy_bonus: entropy term.
    :return: the combined loss.
    """
    weighted_value_loss = VALUE_COEFFICIENT * value_loss
    weighted_entropy = ENTROPY_COEFFICIENT * entropy_bonus
    return policy_loss + weighted_value_loss - weighted_entropy