"""
PPO Trainer for continuous control tasks.

Orchestrates the full training loop:
1. Collect rollouts from environment
2. Compute advantages using GAE
3. Update policy and value networks using PPO loss
4. Log metrics and save checkpoints

Based on CleanRL's PPO implementation with clean modular structure.
"""

from pathlib import Path
from datetime import datetime
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym

from ppo_agent import PPOAgent
from rollout_buffer import RolloutBuffer
from config import PPOConfig
from logger import Logger


class PPOTrainer:
    """
    Manages PPO training loop.

    Responsibilities:
    - Collect experience via environment interaction
    - Compute advantages and returns
    - Optimize policy using clipped objective
    - Log training progress and save checkpoints
    - Optionally run periodic validation with patience-based early stopping
    """

    def __init__(self, config: PPOConfig, agent: PPOAgent, envs: gym.vector.VectorEnv, run_dir: Path, device: torch.device):
        """
        Args:
            config: Training configuration (hyperparameters)
            agent: PPO agent to train
            envs: Vectorized environments
            run_dir: Output directory for logs/checkpoints
            device: Torch device (cpu/cuda)
        """
        self.config = config
        self.agent = agent
        self.envs = envs
        self.run_dir = run_dir
        self.device = device

        # Optimizer (eps=1e-5 as in CleanRL's PPO)
        self.optimizer = optim.Adam(agent.parameters(), lr=config.learning_rate, eps=1e-5)

        # Rollout buffer holding one (num_steps x num_envs) rollout at a time
        self.buffer = RolloutBuffer(
            num_steps=config.num_steps,
            num_envs=config.num_envs,
            obs_shape=envs.single_observation_space.shape,
            action_shape=envs.single_action_space.shape,
            device=device,
            gamma=config.gamma,
            gae_lambda=config.gae_lambda,
        )

        # Logger
        self.logger = Logger(run_dir)

        # Training state
        self.global_step = 0
        self.start_time = time.time()

        # global_step values at which each periodic action last ran
        self.last_checkpoint_step = 0
        self.last_eval_step = 0
        self.last_log_step = 0

        # Early stopping state
        self.best_val_return = float('-inf')
        self.patience_counter = 0
        self.should_stop = False

    def train(self) -> None:
        """
        Main training loop.

        Runs ``config.total_timesteps // (num_steps * num_envs)`` updates. Each update
        collects one rollout, computes advantages/returns, and optimizes the agent.
        It then handles three periodic actions, each on its own schedule:
        logging (log_freq), checkpointing (save_freq), and validation for early
        stopping (eval_freq, only when ``config.early_stopping`` is set).

        A final checkpoint is written when the loop ends normally. The vector envs
        and the logger are always closed, even if training raises.
        """
        self._print_run_header()

        try:
            # Initialize environment
            next_obs = torch.Tensor(self.envs.reset()[0]).to(self.device)
            next_done = torch.zeros(self.config.num_envs).to(self.device)

            num_updates = self.config.total_timesteps // (self.config.num_steps * self.config.num_envs)

            for update in range(1, num_updates + 1):
                is_last_update = update == num_updates

                # Linear learning-rate annealing: full LR on the first update, LR/num_updates on the last.
                # TODO: just use a torch scheduler
                if self.config.anneal_lr:
                    frac = 1.0 - (update - 1.0) / num_updates
                    self.optimizer.param_groups[0]["lr"] = frac * self.config.learning_rate

                # Collect rollout
                next_obs, next_done = self._collect_rollout(next_obs, next_done)

                # Compute advantages, bootstrapping from the value of the post-rollout observation
                with torch.no_grad():
                    next_value = self.agent.get_value(next_obs).reshape(1, -1)
                self.buffer.compute_returns_and_advantages(next_value, next_done)

                # Update policy
                train_stats = self._update_policy()

                # Logging (every log_freq steps, plus always after the first update)
                if self.global_step - self.last_log_step >= self.config.log_freq or update == 1:
                    self._log_metrics(train_stats)
                    self.last_log_step = self.global_step

                # Periodic checkpoint (also on the last update)
                if self.global_step - self.last_checkpoint_step >= self.config.save_freq or is_last_update:
                    self._save_checkpoint()
                    self.last_checkpoint_step = self.global_step

                # Validation has its own eval_freq schedule, independent of save_freq,
                # so eval_freq is honoured even when it is shorter than save_freq.
                if self.config.early_stopping and (
                    self.global_step - self.last_eval_step >= self.config.eval_freq or is_last_update
                ):
                    self._validate_and_update_patience()

                # Exit if early stopping triggered
                if self.should_stop:
                    break

            # Final checkpoint
            self._save_checkpoint(final=True)
        finally:
            # Release env workers and flush logs even if training raised.
            self.envs.close()
            self.logger.close()

        print("=" * 60)
        if self.should_stop:
            print("Training stopped early!")
            print(f"Best validation return: {self.best_val_return:.2f}")
            print(f"Best checkpoint: {self.run_dir}/checkpoints/ckpt_best.pt")
        else:
            print("Training completed!")
        print(f"Run directory: {self.run_dir}")
        print("=" * 60)

    def _print_run_header(self) -> None:
        """Print a banner summarising the run's configuration to stdout."""
        cfg = self.config
        print("=" * 60)
        print(f"Starting PPO Training: {self.run_dir.name}")
        print("=" * 60)
        print(f"Environment: {cfg.env_id}")
        print(f"Total timesteps: {cfg.total_timesteps:,}")
        print(f"Num envs: {cfg.num_envs}")
        print(f"Steps per rollout: {cfg.num_steps}")
        print(f"Batch size: {cfg.num_envs * cfg.num_steps}")
        print(f"Learning rate: {cfg.learning_rate}")
        print(f"Minibatches: {cfg.num_minibatches}")
        print(f"Update epochs: {cfg.update_epochs}")
        print(f"GAE lambda: {cfg.gae_lambda}")
        print(f"Clip coef: {cfg.clip_coef}")
        print(f"Value coef: {cfg.vf_coef}")
        print(f"Entropy coef: {cfg.ent_coef}")
        print(f"Device: {self.device}")
        print(f"Save freq: {cfg.save_freq:,}")
        if cfg.early_stopping:
            print(f"Early stopping: Enabled (patience={cfg.patience}, eval_freq={cfg.eval_freq:,})")
        else:
            print("Early stopping: Disabled")
        print("=" * 60)

    def _validate_and_update_patience(self) -> None:
        """
        Run validation and update the early-stopping state.

        Compares the returned validation return with the best so far. The best
        is updated only when the new return beats it by more than
        ``config.min_improvement``. An improvement resets the patience counter
        and saves the "best" checkpoint; otherwise the counter is incremented.
        Sets ``self.should_stop`` once the counter reaches ``config.patience``.
        """
        val_return = self._validate()
        self.last_eval_step = self.global_step

        # Check for improvement
        if val_return > self.best_val_return + self.config.min_improvement:
            print(f"New best validation return: {val_return:.2f} (prev: {self.best_val_return:.2f})")
            self.best_val_return = val_return
            self.patience_counter = 0
            self._save_checkpoint(best=True)
        else:
            self.patience_counter += 1
            print(f"No improvement. Patience: {self.patience_counter}/{self.config.patience}")

        # Early stopping check
        if self.patience_counter >= self.config.patience:
            print(f"Early stopping triggered after {self.global_step:,} steps")
            print(f"   Best validation return: {self.best_val_return:.2f}")
            self.should_stop = True

    def _collect_rollout(self, next_obs: torch.Tensor, next_done: torch.Tensor):
        """
        Collect experience for one rollout of ``config.num_steps`` vector-env steps.

        Each step stores the observation and done flag the action was taken from,
        together with the action, its log-prob, the reward, and the value estimate.
        ``global_step`` advances by ``num_envs`` per step.

        Args:
            next_obs: Starting observation
            next_done: Starting done flag

        Returns:
            Tuple of (final_obs, final_done), used to bootstrap and to start the next rollout.
        """
        for step in range(self.config.num_steps):
            self.global_step += self.config.num_envs

            # Get action and value from agent
            with torch.no_grad():
                action, logprob, _, value = self.agent.get_action_and_value(next_obs)

            # Step environment; an episode ends on either termination or truncation
            next_obs_np, reward, terminations, truncations, infos = self.envs.step(action.cpu().numpy())
            next_done_np = np.logical_or(terminations, truncations)

            # Store transition. Rewards arrive as float64 NumPy arrays; convert them
            # to float32 on the target device so they match the other rollout tensors.
            self.buffer.add(
                step=step,
                obs=next_obs,
                action=action,
                logprob=logprob,
                reward=torch.as_tensor(reward, dtype=torch.float32, device=self.device),
                done=next_done,
                value=value.flatten(),
            )

            # Update state
            next_obs = torch.Tensor(next_obs_np).to(self.device)
            next_done = torch.Tensor(next_done_np).to(self.device)

            # Log episode statistics for envs that finished an episode this step.
            # NOTE(review): assumes the "final_info" layout (a sequence of per-env dicts
            # or None); confirm against the installed gymnasium version.
            if "final_info" in infos:
                for info in infos["final_info"]:
                    if info and "episode" in info:
                        self.logger.log_scalar("charts/episodic_return", info["episode"]["r"], self.global_step)
                        self.logger.log_scalar("charts/episodic_length", info["episode"]["l"], self.global_step)

        return next_obs, next_done

    def _update_policy(self) -> dict:
        """
        Update policy and value networks using the PPO objective.

        Runs ``config.update_epochs`` passes over the flattened rollout batch in
        shuffled minibatches, using the clipped surrogate policy loss, an optionally
        clipped value loss, and an entropy bonus. If ``config.target_kl`` is set,
        the remaining epochs are skipped once an epoch's mean approximate KL
        exceeds it.

        Returns:
            Dictionary of training statistics: mean pg/value/entropy losses,
            approx_kl and clipfrac over all optimized minibatches, and the
            explained variance of the rollout values w.r.t. the returns.
        """
        batch = self.buffer.get_flat_batch()
        batch_size = self.config.num_steps * self.config.num_envs
        minibatch_size = batch_size // self.config.num_minibatches

        # Training metrics (one entry per minibatch)
        clipfracs = []
        pg_losses = []
        v_losses = []
        entropy_losses = []
        approx_kls = []

        b_inds = np.arange(batch_size)

        for _ in range(self.config.update_epochs):
            np.random.shuffle(b_inds)
            epoch_kl_start = len(approx_kls)

            for start in range(0, batch_size, minibatch_size):
                end = start + minibatch_size
                mb_inds = b_inds[start:end]

                # Forward pass: re-evaluate the stored actions under the current policy
                _, newlogprob, entropy, newvalue = self.agent.get_action_and_value(
                    batch["obs"][mb_inds], batch["actions"][mb_inds]
                )
                logratio = newlogprob - batch["logprobs"][mb_inds]
                ratio = logratio.exp()

                # Metrics: low-variance KL estimator (ratio - 1) - log(ratio)
                with torch.no_grad():
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs.append(((ratio - 1.0).abs() > self.config.clip_coef).float().mean().item())
                    approx_kls.append(approx_kl.item())

                # Normalize advantages (per minibatch)
                mb_advantages = batch["advantages"][mb_inds]
                if self.config.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss (clipped objective; max of negated terms == min of surrogates)
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - self.config.clip_coef, 1 + self.config.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss (optionally clipped around the rollout-time value estimate)
                newvalue = newvalue.view(-1)
                if self.config.clip_vloss:
                    v_loss_unclipped = (newvalue - batch["returns"][mb_inds]) ** 2
                    v_clipped = batch["values"][mb_inds] + torch.clamp(
                        newvalue - batch["values"][mb_inds], -self.config.clip_coef, self.config.clip_coef
                    )
                    v_loss_clipped = (v_clipped - batch["returns"][mb_inds]) ** 2
                    v_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
                else:
                    v_loss = 0.5 * ((newvalue - batch["returns"][mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()

                # Total loss (entropy is subtracted to encourage exploration)
                loss = pg_loss - self.config.ent_coef * entropy_loss + v_loss * self.config.vf_coef

                # Optimization step
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.agent.parameters(), self.config.max_grad_norm)
                self.optimizer.step()

                # Collect stats
                pg_losses.append(pg_loss.item())
                v_losses.append(v_loss.item())
                entropy_losses.append(entropy_loss.item())

            # Early stopping with KL divergence. Use only this epoch's KL estimates: a
            # running mean over all epochs includes the near-zero values from the first
            # epoch and so underestimates how far the policy has currently moved.
            if self.config.target_kl is not None:
                epoch_kl = np.mean(approx_kls[epoch_kl_start:])
                if epoch_kl > self.config.target_kl:
                    break

        # Explained variance of rollout value predictions w.r.t. computed returns
        y_pred, y_true = batch["values"].cpu().numpy(), batch["returns"].cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        return {
            "pg_loss": np.mean(pg_losses),
            "v_loss": np.mean(v_losses),
            "entropy_loss": np.mean(entropy_losses),
            "approx_kl": np.mean(approx_kls),
            "clipfrac": np.mean(clipfracs),
            "explained_variance": explained_var,
        }

    def _log_metrics(self, train_stats: dict) -> None:
        """Log training metrics via the run Logger and print a one-line summary to the console."""
        lr = self.optimizer.param_groups[0]["lr"]

        self.logger.log_scalar("charts/learning_rate", lr, self.global_step)
        self.logger.log_scalar("losses/policy_loss", train_stats["pg_loss"], self.global_step)
        self.logger.log_scalar("losses/value_loss", train_stats["v_loss"], self.global_step)
        self.logger.log_scalar("losses/entropy", train_stats["entropy_loss"], self.global_step)
        self.logger.log_scalar("losses/approx_kl", train_stats["approx_kl"], self.global_step)
        self.logger.log_scalar("losses/clipfrac", train_stats["clipfrac"], self.global_step)
        self.logger.log_scalar("losses/explained_variance", train_stats["explained_variance"], self.global_step)

        # Steps per second since trainer construction
        sps = int(self.global_step / (time.time() - self.start_time))
        print(f"Step {self.global_step:,} | SPS: {sps} | PG Loss: {train_stats['pg_loss']:.3f} | V Loss: {train_stats['v_loss']:.3f}")

    def _validate(self) -> float:
        """
        Run validation episodes to evaluate current policy.

        Creates a fresh single environment from ``config.env_id``. Each episode
        ``ep`` is seeded with ``config.seed + 10000 + ep``, so validation uses
        seeds distinct from ``config.seed`` and is repeatable across calls. The
        environment is closed even if an episode raises.

        Returns:
            Mean undiscounted return over validation episodes
        """
        print(f"\n{'='*60}")
        print(f"Running validation at step {self.global_step:,}")
        print(f"{'='*60}")

        # Create a single evaluation environment.
        # NOTE(review): built from env_id alone; any extra wrappers applied to the
        # training envs (e.g. observation normalization or action clipping) are not
        # applied here. Confirm this matches how the training envs are constructed.
        eval_env = gym.make(self.config.env_id)
        episode_returns = []

        try:
            eval_env = gym.wrappers.RecordEpisodeStatistics(eval_env)

            for ep in range(self.config.eval_episodes):
                obs, _ = eval_env.reset(seed=self.config.seed + 10000 + ep)
                done = False
                truncated = False
                episode_return = 0.0

                while not (done or truncated):
                    obs_tensor = torch.Tensor(obs).unsqueeze(0).to(self.device)
                    # NOTE(review): presumably get_action_and_value samples from the policy,
                    # so validation returns would include exploration noise — confirm.
                    with torch.no_grad():
                        action, _, _, _ = self.agent.get_action_and_value(obs_tensor)
                    obs, reward, done, truncated, _ = eval_env.step(action.cpu().numpy()[0])
                    episode_return += reward

                episode_returns.append(episode_return)
        finally:
            eval_env.close()

        mean_return = float(np.mean(episode_returns))
        std_return = float(np.std(episode_returns))

        # Log validation summary
        self.logger.log_scalar("validation/mean_return", mean_return, self.global_step)
        self.logger.log_scalar("validation/std_return", std_return, self.global_step)

        print("Validation results:")
        print(f"  Mean return: {mean_return:.2f} ± {std_return:.2f}")
        print(f"{'='*60}\n")

        return mean_return

    def _save_checkpoint(self, final: bool = False, best: bool = False) -> None:
        """
        Save agent checkpoint under ``run_dir/checkpoints``.

        Args:
            final: Write ``ckpt_final.pt`` (takes precedence over ``best``).
            best: Write ``ckpt_best.pt``.
            Otherwise the file is ``ckpt_step_<global_step>.pt``.

        The checkpoint holds global_step, agent and optimizer state dicts, and the config object.
        """
        checkpoint_dir = self.run_dir / "checkpoints"
        # parents=True so saving works even if run_dir itself has not been created yet
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        if final:
            checkpoint_path = checkpoint_dir / "ckpt_final.pt"
        elif best:
            checkpoint_path = checkpoint_dir / "ckpt_best.pt"
        else:
            checkpoint_path = checkpoint_dir / f"ckpt_step_{self.global_step}.pt"

        # NOTE(review): "config" is pickled as a Python object; loading with
        # torch.load(weights_only=True) will reject it unless PPOConfig is allow-listed.
        torch.save({
            "global_step": self.global_step,
            "agent_state_dict": self.agent.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "config": self.config,
        }, checkpoint_path)

        print(f"Checkpoint saved: {checkpoint_path}")

