"""
PPO Actor-Critic agent for continuous control.

Implements a trainable policy network (actor) and value network (critic)
using the Proximal Policy Optimization algorithm. The agent learns from
experience through gradient-based optimization.

Architecture (dimensions shown for a 23-dim observation / 7-dim action
environment; all sizes are configurable via the constructor):
- Actor: 23-dim obs → 64 → 64 → 7-dim action mean + learned log std
- Critic: 23-dim obs → 64 → 64 → 1-dim value
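
Example (illustrative; the 23/7 sizes match the notes above and are otherwise
arbitrary):

    agent = PPOAgent(obs_dim=23, action_dim=7)
    action = agent.get_action(np.zeros(23, dtype=np.float32))  # sampled action, shape (7,)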
"""

from pathlib import Path
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
from torch.distributions.normal import Normal


# ------------------------------------------------------------
# Helper Functions
# ------------------------------------------------------------

def layer_init(layer: nn.Linear, std: float = np.sqrt(2), bias_const: float = 0.0) -> nn.Linear:
    """Orthogonally initialize a layer's weights (gain ``std``) and fill its bias with ``bias_const``."""
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


# ------------------------------------------------------------
# PPO Agent
# ------------------------------------------------------------

class PPOAgent(nn.Module):
    """
    PPO Actor-Critic agent.
    
    Combines a policy network (actor) that selects actions and a value network
    (critic) that estimates state values for advantage computation.
    """
    
    def __init__(self, obs_dim: int, action_dim: int, hidden_dim: int = 64):
        """
        Args:
            obs_dim: Observation space dimensionality
            action_dim: Action space dimensionality
            hidden_dim: Hidden layer size (default: 64)
        """
        super().__init__()
        
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        
        # Critic (value function): obs → value
        self.critic = nn.Sequential(
            layer_init(nn.Linear(obs_dim, hidden_dim)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden_dim, hidden_dim)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden_dim, 1), std=1.0),
        )
        
        # Actor (policy): obs → action mean
        self.actor_mean = nn.Sequential(
            layer_init(nn.Linear(obs_dim, hidden_dim)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden_dim, hidden_dim)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden_dim, action_dim), std=0.01),
        )
        
        # Learned log standard deviation for action distribution
        self.actor_logstd = nn.Parameter(torch.zeros(1, action_dim))
    
    def get_value(self, x: torch.Tensor) -> torch.Tensor:
        """
        Get state value estimate.
        
        Args:
            x: Observation tensor
            
        Returns:
            Value estimate (shape ``(batch, 1)`` for a batched observation)
        """
        return self.critic(x)
    
    def get_action_and_value(self, x: torch.Tensor, action: Optional[torch.Tensor] = None):
        """
        Get action, log probability, entropy, and value.
        
        Used during both rollout collection and policy updates.
        
        Args:
            x: Batched observation tensor of shape ``(batch, obs_dim)``
            action: Optional action tensor (for recomputing log prob during updates)
            
        Returns:
            Tuple of (action, log_prob, entropy, value)
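
        Example (illustrative sketch; ``agent`` is a constructed PPOAgent and the
        batch size 32 is arbitrary):

            obs_batch = torch.zeros(32, agent.obs_dim)
            action, logp, entropy, value = agent.get_action_and_value(obs_batch)
            # During PPO updates, pass the stored actions back in to re-evaluate them:
            _, new_logp, new_entropy, new_value = agent.get_action_and_value(obs_batch, action)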
        """
        action_mean = self.actor_mean(x)
        action_logstd = self.actor_logstd.expand_as(action_mean)
        action_std = torch.exp(action_logstd)
        probs = Normal(action_mean, action_std)
        
        if action is None:
            action = probs.sample()
        
        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
    
    def get_action(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray:
        """
        Select action given observation (Agent interface).
        
        Args:
            obs: Observation array
            deterministic: If True, return mean action (no sampling)
            
        Returns:
            Action array
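
        Example (illustrative; ``agent`` is a constructed PPOAgent and ``obs`` a
        matching observation array):

            action = agent.get_action(obs)                      # stochastic, for rollouts
            action = agent.get_action(obs, deterministic=True)  # mean action, for evaluation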
        """
        with torch.no_grad():
            # Build the observation tensor on the same device as the network parameters.
            device = next(self.parameters()).device
            obs_tensor = torch.tensor(obs, dtype=torch.float32, device=device)
            if obs_tensor.dim() == 1:
                obs_tensor = obs_tensor.unsqueeze(0)
            
            action_mean = self.actor_mean(obs_tensor)
            
            if deterministic:
                action = action_mean
            else:
                action_logstd = self.actor_logstd.expand_as(action_mean)
                action_std = torch.exp(action_logstd)
                probs = Normal(action_mean, action_std)
                action = probs.sample()
            
            return action.cpu().numpy().squeeze()
    
    def save(self, path: Path) -> None:
        """
        Save agent state to disk.
        
        Args:
            path: File path to save to (.pt file)
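
        Example (illustrative path; assumes the target directory exists):

            agent.save(Path("checkpoints/ppo_agent.pt"))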
        """
        torch.save({
            "obs_dim": self.obs_dim,
            "action_dim": self.action_dim,
            "state_dict": self.state_dict(),
        }, path)
    
    def load(self, path: Path) -> None:
        """
        Load agent state from disk.
        
        Args:
            path: File path to load from (.pt file)
        """
        # Import here to avoid circular dependency
        from checkpointing import load_checkpoint
        
        checkpoint = load_checkpoint(path, agent=None, map_location="cpu")
        
        # Handle both old and new checkpoint formats
        if "state_dict" in checkpoint:
            self.load_state_dict(checkpoint["state_dict"])
        elif "agent_state_dict" in checkpoint:
            self.load_state_dict(checkpoint["agent_state_dict"])
        else:
            raise KeyError(
                "Checkpoint must contain 'state_dict' or 'agent_state_dict', "
                f"got keys: {list(checkpoint.keys())}"
            )


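# ------------------------------------------------------------
# Smoke Test
# ------------------------------------------------------------

if __name__ == "__main__":
    # Illustrative sketch: the 23-dim obs / 7-dim action sizes follow the
    # module docstring and are otherwise arbitrary.
    agent = PPOAgent(obs_dim=23, action_dim=7)

    # Single-observation path used during evaluation.
    single_obs = np.random.randn(23).astype(np.float32)
    print("eval action shape:", agent.get_action(single_obs, deterministic=True).shape)

    # Batched path used during rollout collection and PPO updates.
    obs_batch = torch.randn(32, 23)
    action, logp, entropy, value = agent.get_action_and_value(obs_batch)
    print("batched shapes:", action.shape, logp.shape, entropy.shape, value.shape)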