"""
Configuration dataclasses for training and evaluation.

Defines PPOConfig with all hyperparameters and settings for PPO training.
"""

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class PPOConfig:
    """
    Configuration for PPO training.

    Includes:
    - Environment settings
    - Algorithm hyperparameters
    - Training loop settings
    - Logging and checkpointing

    Raises:
        ValueError: if ``num_steps`` is not evenly divisible by
            ``num_minibatches`` (minibatch slicing would drop samples).
    """

    # Experiment
    exp_name: str = "ppo_pusher"
    seed: int = 1
    torch_deterministic: bool = True  # Set torch deterministic algorithms flag
    cuda: bool = True  # Use GPU if available

    # Environment
    env_id: str = "Pusher-v5"
    num_envs: int = 16  # Number of parallel environments

    # Algorithm hyperparameters
    total_timesteps: int = 2_000_000
    learning_rate: float = 5e-4
    num_steps: int = 2048  # Steps per rollout (per environment)
    anneal_lr: bool = True  # Linearly decay learning rate over training
    gamma: float = 0.99  # Discount factor
    gae_lambda: float = 0.95  # GAE lambda
    num_minibatches: int = 32  # Minibatches per update epoch
    update_epochs: int = 10  # Optimization epochs per rollout
    norm_adv: bool = True  # Normalize advantages
    clip_coef: float = 0.2  # PPO clipping coefficient
    clip_vloss: bool = True  # Clip value function loss
    ent_coef: float = 0.0  # Entropy coefficient
    vf_coef: float = 1.0  # Value function coefficient (float literal to match annotation)
    max_grad_norm: float = 0.5  # Gradient clipping
    target_kl: Optional[float] = None  # Early stopping KL threshold (None disables)

    # Checkpointing
    save_freq: int = 100_000  # Save checkpoint every N steps

    # Logging
    log_freq: int = 1  # Log every N updates (every update with num_steps=2048)

    # Early stopping
    early_stopping: bool = True  # Enable validation-based early stopping
    eval_freq: int = 100_000  # Evaluate every N steps (should match save_freq)
    eval_episodes: int = 20  # Number of episodes to evaluate
    patience: int = 3  # Number of evaluations without improvement before stopping
    min_improvement: float = 1.0  # Minimum improvement in mean return to reset patience

    def __post_init__(self) -> None:
        """Validate configuration after dataclass initialization.

        Uses an explicit ``raise`` rather than ``assert`` so the check
        survives ``python -O`` (asserts are stripped under optimization).
        """
        if self.num_steps % self.num_minibatches != 0:
            raise ValueError(
                f"num_steps ({self.num_steps}) must be divisible by "
                f"num_minibatches ({self.num_minibatches})"
            )
