# Source code for splendor.agents.our_agents.ppo.input_norm
import torch
import torch.nn as nn
from .constants import VERY_SMALL_EPSILON
# [docs]
class InputNormalization(nn.Module):
    """Feature-wise input normalization with running statistics.

    In training mode the input is normalized with the mean and (biased)
    variance computed over dimension 0 of the current batch, and those
    statistics are folded into the ``running_mean`` / ``running_var``
    buffers with an exponential moving average.  In evaluation mode the
    stored running statistics are used instead and are left untouched.

    The module has no learnable parameters; the running statistics are
    registered as buffers of shape ``(1, num_features)``.

    Args:
        num_features: Size of the feature dimension; determines the shape
            of the running-statistics buffers.
        epsilon: Value added to the variance before taking the square
            root, to avoid division by zero.
        momentum: Weight given to the current batch statistics when
            updating the running statistics
            (``running = (1 - momentum) * running + momentum * batch``).
            The default of ``0.1`` reproduces the previously hard-coded
            ``0.9`` / ``0.1`` update.

    Raises:
        ValueError: If ``momentum`` is outside ``[0, 1]``.
    """

    # Class-level fallback so that instances created (e.g. unpickled)
    # before ``momentum`` existed as an instance attribute keep the
    # original 0.9 / 0.1 update rule.
    momentum = 0.1

    def __init__(
        self,
        num_features: int,
        epsilon: float = VERY_SMALL_EPSILON,
        momentum: float = 0.1,
    ) -> None:
        super().__init__()
        if not 0.0 <= momentum <= 1.0:
            raise ValueError(
                f"momentum must be in the range [0, 1], got {momentum!r}"
            )
        self.num_features = num_features
        self.epsilon = epsilon
        self.momentum = momentum
        self.register_buffer("running_mean", torch.zeros(1, num_features))
        self.register_buffer("running_var", torch.ones(1, num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` and, in training mode, update the running stats.

        Args:
            x: Input tensor.  Statistics are reduced over dimension 0, so
                it is expected to be shaped ``(batch, num_features)``
                (assumption based on the buffer shapes -- confirm against
                callers).

        Returns:
            ``(x - mean) / sqrt(var + epsilon)``, where ``mean`` / ``var``
            are the batch statistics in training mode and the running
            statistics in evaluation mode.

        Raises:
            IndexError: In training mode, if ``x`` has fewer than two
                dimensions.  (The exception type matches what the
                previous implementation raised for such inputs.)
        """
        if self.training:
            if x.dim() < 2:
                raise IndexError(
                    "InputNormalization expects an input with at least 2 "
                    f"dimensions in training mode, got {x.dim()}"
                )

            # NOTE: for a batch containing a single row the biased variance
            # is 0 and ``x - mean`` is 0, so the output is all zeros.
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)

            # The running statistics are bookkeeping only and must not be
            # tracked by autograd.
            with torch.no_grad():
                keep = 1 - self.momentum
                self.running_mean = (
                    self.running_mean * keep + mean * self.momentum
                )
                self.running_var = (
                    self.running_var * keep + var * self.momentum
                )
        else:
            mean = self.running_mean
            var = self.running_var

        return (x - mean) / torch.sqrt(var + self.epsilon)