from dataclasses import dataclass, field
from typing import Optional, Tuple
import torch
from .common import (
    calculate_returns,
    calculate_advantages,
)

@dataclass
class RolloutBuffer:
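    """Pre-allocated, fixed-size storage for a single on-policy rollout.

    Holds states, actions, action masks, log-probabilities of actions,
    value estimates, rewards, done flags, and (when ``is_recurrent`` is
    set) recurrent hidden states. ``remember`` writes one transition per
    call until the buffer is full; ``unpack`` returns the filled slice
    together with advantages and returns for a policy update.
    """
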
    size: int
    input_dim: int
    action_dim: int
    is_recurrent: bool = False
    hidden_states_dim: Optional[int] = None
    device: Optional[torch.device] = None

    # Storage tensors; allocated in __post_init__.
    states: torch.Tensor = field(init=False)
    actions: torch.Tensor = field(init=False)
    action_mask_history: torch.Tensor = field(init=False)
    log_prob_actions: torch.Tensor = field(init=False)
    values: torch.Tensor = field(init=False)
    rewards: torch.Tensor = field(init=False)
    dones: torch.Tensor = field(init=False)
    hidden_states: Optional[torch.Tensor] = field(init=False)

    # Write cursor and full flag; managed by remember() and clear().
    index: int = field(default=0, init=False)
    full: bool = field(default=False, init=False)

    def __post_init__(self):
        self.index = 0
        self.full = False
        if self.device is None:
            self.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )
        self.states = torch.zeros(
            (self.size, self.input_dim), dtype=torch.float64, device=self.device
        )
        self.actions = torch.zeros(
            (self.size, 1), dtype=torch.float64, device=self.device
        )
        self.action_mask_history = torch.zeros(
            (self.size, self.action_dim), dtype=torch.float64, device=self.device
        )
        self.log_prob_actions = torch.zeros(
            (self.size, 1), dtype=torch.float64, device=self.device
        )
        self.values = torch.zeros(
            (self.size, 1), dtype=torch.float64, device=self.device
        )
        self.rewards = torch.zeros(
            (self.size, 1), dtype=torch.float64, device=self.device
        )
        self.dones = torch.zeros(
            (self.size, 1), dtype=torch.bool, device=self.device
        )
        if self.is_recurrent:
            if self.hidden_states_dim is None:
                raise ValueError(
                    "hidden_states_dim must be an int when is_recurrent is set"
                )
            self.hidden_states = torch.zeros(
                (self.size, 1, self.hidden_states_dim),
                dtype=torch.float64,
                device=self.device,
            )
        else:
            # Placeholder tensor so unpack() can slice hidden_states
            # uniformly in the non-recurrent case; kept on the same
            # device as the other buffers.
            self.hidden_states = torch.zeros(
                self.size, dtype=torch.float64, device=self.device
            )

    def remember(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
        action_mask: torch.Tensor,
        log_prob_action: torch.Tensor,
        value: float,
        reward: float,
        done: bool,
        hidden_state: Optional[torch.Tensor] = None,
    ):
        """Store one transition at the current index.

        Calls made after the buffer is full are silently ignored.
        """
        with torch.no_grad():
            if self.full:
                return
            self.states[self.index] = state
            self.actions[self.index] = action
            self.action_mask_history[self.index] = action_mask
            self.log_prob_actions[self.index] = log_prob_action
            self.values[self.index] = value
            self.rewards[self.index] = reward
            self.dones[self.index] = done
            if self.is_recurrent:
                self.hidden_states[self.index] = hidden_state
            self.index += 1
            if self.index >= self.size:
                self.full = True

    def clear(self):
        """Reset the write cursor; old data is overwritten by later
        remember() calls rather than zeroed out."""
        self.index = 0
        self.full = False

    def calculate_gae(
        self, discount_factor: float
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute Generalized Advantage Estimation (GAE) advantages and
        discounted returns over the filled portion of the buffer,
        delegating to the helpers in ``common``.
        """
        with torch.no_grad():
            returns = calculate_returns(
                self.rewards[: self.index], discount_factor
            ).to(self.device)
            advantages = calculate_advantages(
                returns, self.values[: self.index]
            ).to(self.device)
        return advantages, returns
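
    # A sketch of the semantics assumed of the ``common`` helpers (their
    # exact definitions live in that module): for each step t of the
    # filled slice,
    #
    #     returns[t]    = rewards[t] + discount_factor * returns[t + 1]
    #     advantages[t] = returns[t] - values[t]
    #
    # i.e. discounted returns followed by a returns-minus-values baseline.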

    def unpack(self, discount_factor: float) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        """Return the filled slice of every buffer plus the advantages
        and returns from calculate_gae(), in the order: hidden_states,
        states, actions, action_masks, log_prob_actions, advantages,
        returns, dones.
        """
        hidden_states = self.hidden_states[: self.index]
        states = self.states[: self.index]
        actions = self.actions[: self.index]
        action_masks = self.action_mask_history[: self.index]
        log_prob_actions = self.log_prob_actions[: self.index]
        dones = self.dones[: self.index]
        advantages, returns = self.calculate_gae(discount_factor)
        return (
            hidden_states,
            states,
            actions,
            action_masks,
            log_prob_actions,
            advantages,
            returns,
            dones,
        )
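
# A minimal usage sketch (hypothetical dimensions and values; assumes the
# buffer is driven by an external training loop, which this module does
# not provide):
#
#     buffer = RolloutBuffer(size=128, input_dim=8, action_dim=4)
#     for step in range(128):
#         buffer.remember(
#             state=torch.zeros(8, dtype=torch.float64),
#             action=torch.tensor(0.0),
#             action_mask=torch.ones(4, dtype=torch.float64),
#             log_prob_action=torch.tensor(-1.386),
#             value=0.0,
#             reward=1.0,
#             done=(step == 127),
#         )
#     (hidden_states, states, actions, action_masks,
#      log_prob_actions, advantages, returns, dones) = buffer.unpack(0.99)
#     buffer.clear()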