"""
Rollout buffer for storing experience and computing advantages.

The buffer stores transitions during environment interaction and
computes Generalized Advantage Estimation (GAE) for policy updates.

GAE balances the bias-variance tradeoff in advantage estimation using:
- γ (gamma): discount factor for future rewards
- λ (lambda): GAE smoothing parameter
"""

import torch
import numpy as np


class RolloutBuffer:
    """
    Storage for rollout data and GAE computation.
    
    Stores:
    - Observations, actions, log probs
    - Rewards, dones, values
    
    Computes:
    - Advantages using GAE
    - Returns (advantages + values)
    """
    
    def __init__(
        self,
        num_steps: int,
        num_envs: int,
        obs_shape: tuple,
        action_shape: tuple,
        device: torch.device,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
    ):
        """
        Args:
            num_steps: Steps per rollout
            num_envs: Number of parallel environments
            obs_shape: Observation space shape
            action_shape: Action space shape
            device: Torch device (cpu/cuda)
            gamma: Discount factor
            gae_lambda: GAE lambda parameter
        """
        self.num_steps = num_steps
        self.num_envs = num_envs
        self.obs_shape = obs_shape
        self.action_shape = action_shape
        self.device = device
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        
        # Storage tensors, allocated directly on the target device
        self.obs = torch.zeros((num_steps, num_envs) + obs_shape, device=device)
        self.actions = torch.zeros((num_steps, num_envs) + action_shape, device=device)
        self.logprobs = torch.zeros((num_steps, num_envs), device=device)
        self.rewards = torch.zeros((num_steps, num_envs), device=device)
        self.dones = torch.zeros((num_steps, num_envs), device=device)
        self.values = torch.zeros((num_steps, num_envs), device=device)
        
        # Computed during GAE
        self.advantages = None
        self.returns = None
    
    def add(self, step: int, obs, action, logprob, reward, done, value):
        """
        Add a transition to the buffer.
        
        Args:
            step: Current step index (0 to num_steps-1)
            obs: Observation tensor, shape (num_envs, *obs_shape)
            action: Action tensor, shape (num_envs, *action_shape)
            logprob: Log probability of each action, shape (num_envs,)
            reward: Reward tensor, shape (num_envs,)
            done: Done flags, shape (num_envs,)
            value: Value estimates, shape (num_envs,)
        """
        self.obs[step] = obs
        self.actions[step] = action
        self.logprobs[step] = logprob
        self.rewards[step] = reward
        self.dones[step] = done
        self.values[step] = value
    
    def compute_returns_and_advantages(self, next_value: torch.Tensor, next_done: torch.Tensor):
        """
        Compute advantages using Generalized Advantage Estimation (GAE).
        
        GAE formula:
            A_t = δ_t + (γλ)δ_{t+1} + (γλ)^2 δ_{t+2} + ...
        where:
            δ_t = r_t + γV(s_{t+1}) - V(s_t)
        
        Args:
            next_value: Value estimate for next state (after rollout)
            next_done: Done flag for next state
        """
        with torch.no_grad():
            advantages = torch.zeros_like(self.rewards)  # same shape/device/dtype as rewards
            lastgaelam = 0.0
            
            # Iterate backwards so each step can reuse the accumulated term from t + 1
            for t in reversed(range(self.num_steps)):
                if t == self.num_steps - 1:
                    # Last rollout step: bootstrap from the value/done of the
                    # state observed after the rollout
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - self.dones[t + 1]
                    nextvalues = self.values[t + 1]
                
                # One-step TD error: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t),
                # with the bootstrap term masked out when the next step is terminal
                delta = self.rewards[t] + self.gamma * nextvalues * nextnonterminal - self.values[t]
                
                # GAE accumulation: A_t = delta_t + gamma * lambda * A_{t+1};
                # nextnonterminal resets the running sum at episode boundaries
                advantages[t] = lastgaelam = delta + self.gamma * self.gae_lambda * nextnonterminal * lastgaelam
            
            self.advantages = advantages
            self.returns = advantages + self.values
    
    def get_flat_batch(self):
        """
        Flatten the (num_steps, num_envs) leading dimensions into a single
        batch dimension for training.
        
        Returns:
            Dictionary with flattened tensors:
            - obs, actions (trailing shapes preserved)
            - logprobs, advantages, returns, values (1-D)
        """
        return {
            "obs": self.obs.reshape((-1,) + self.obs_shape),
            "actions": self.actions.reshape((-1,) + self.action_shape),
            "logprobs": self.logprobs.reshape(-1),
            "advantages": self.advantages.reshape(-1),
            "returns": self.returns.reshape(-1),
            "values": self.values.reshape(-1),
        }
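

# Minimal usage sketch with random data (illustration only; the environment
# interaction loop, the policy, and the value network are assumed to live
# elsewhere and are replaced here by random tensors).
if __name__ == "__main__":
    num_steps, num_envs, obs_dim = 4, 2, 3
    device = torch.device("cpu")
    buffer = RolloutBuffer(num_steps, num_envs, (obs_dim,), (), device)

    # Fill the buffer with fake transitions for num_steps steps
    for step in range(num_steps):
        buffer.add(
            step,
            obs=torch.randn(num_envs, obs_dim),
            action=torch.randint(0, 2, (num_envs,)).float(),
            logprob=torch.randn(num_envs),
            reward=torch.randn(num_envs),
            done=torch.zeros(num_envs),
            value=torch.randn(num_envs),
        )

    # Bootstrap GAE from a (random) value estimate of the state after the rollout
    buffer.compute_returns_and_advantages(
        next_value=torch.randn(num_envs), next_done=torch.zeros(num_envs)
    )

    # Flattened batch ready for minibatch policy/value updates
    batch = buffer.get_flat_batch()
    for key, tensor in batch.items():
        print(f"{key}: {tuple(tensor.shape)}")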

