Source code for splendor.agents.our_agents.ppo.input_norm

"""
Implementation of an input normalization layer.
"""

import torch
from torch import nn

from .constants import VERY_SMALL_EPSILON


class InputNormalization(nn.Module):
    """
    Input normalization layer - using a running average for calibrating
    the mean & variance.
    """

    # pylint: disable=too-few-public-methods

    def __init__(self, num_features: int, epsilon: float = VERY_SMALL_EPSILON) -> None:
        """
        Create a new input normalization layer.

        :param num_features: how many features to expect in the inputs.
        :param epsilon: the epsilon value to add to the denominator during
                        normalization in order to avoid division by zero.
        """
        super().__init__()
        self.num_features = num_features
        self.epsilon = epsilon
        # Running estimators are registered as buffers so they are saved in
        # the state dict but excluded from gradient updates.
        self.register_buffer("running_mean", torch.zeros(1, num_features))
        self.register_buffer("running_var", torch.ones(1, num_features))
        # Annotations for the registered buffers, for static type checkers.
        self.running_mean: torch.Tensor
        self.running_var: torch.Tensor
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # pylint: disable=attribute-defined-outside-init
        """
        Normalize the input using running mean & variance estimators.
        The output should have a mean of 0 and a variance of 1.

        :param x: the un-normalized input.
        :return: a normalized x.
        """
        if self.training:
            # Use the current batch's statistics and fold them into the
            # running estimators with an exponential moving average.
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)
            with torch.no_grad():
                self.running_mean = self.running_mean * 0.9 + mean * 0.1
                self.running_var = self.running_var * 0.9 + var * 0.1
        else:
            # At evaluation time, rely on the calibrated running estimators.
            mean = self.running_mean
            var = self.running_var

        x_normalized = (x - mean) / torch.sqrt(var + self.epsilon)
        return x_normalized
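

# A minimal usage sketch of InputNormalization. The feature count, batch
# shapes, and the explicit epsilon below are illustrative assumptions, not
# values taken from the surrounding codebase (the real default epsilon is
# VERY_SMALL_EPSILON from .constants).
if __name__ == "__main__":
    norm = InputNormalization(num_features=4, epsilon=1e-8)

    # Training mode: batch statistics are used directly, and the running
    # estimators are updated as new = 0.9 * old + 0.1 * batch.
    norm.train()
    x = torch.randn(32, 4) * 5.0 + 2.0  # batch of 32 samples, 4 features
    y = norm(x)
    print(y.mean(dim=0))                 # approximately 0 per feature
    print(y.var(dim=0, unbiased=False))  # approximately 1 per feature

    # Evaluation mode: the frozen running estimators are used instead of
    # batch statistics, so even a single sample is normalized consistently.
    norm.eval()
    single = torch.randn(1, 4)
    print(norm(single))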