"""
Checkpointing utilities for saving and loading agent state.

Handles serialization of:
- Agent model weights
- Optimizer state
- Training metadata (global step, config)
- Legacy checkpoints with old src.* module paths
"""

import sys
from pathlib import Path
from typing import Optional

import torch

from ppo_agent import PPOAgent


def _install_module_redirects():
    """
    Install module redirects for loading old checkpoints with src.* paths.

    Checkpoints pickled before the module restructuring reference classes
    under the old ``src.*`` package layout; registering aliases in
    ``sys.modules`` lets unpickling resolve them against the new flat modules.
    """
    # Already installed (or a genuine 'src' package is loaded) — nothing to do.
    if 'src' in sys.modules:
        return

    # Pull in every flat module a legacy checkpoint might reference.
    import ppo_agent
    import baseline_agent
    import huggingface_agent
    import trainer
    import rollout_buffer
    import evaluator
    import metrics
    import env_factory
    import config
    import logger
    import video

    # This module stands in for the old 'src' package and its subpackages.
    this_module = sys.modules[__name__]

    redirects = {
        'src': this_module,
        # src.agents.*
        'src.agents': this_module,
        'src.agents.ppo_agent': ppo_agent,
        'src.agents.base': ppo_agent,
        'src.agents.baseline_agent': baseline_agent,
        'src.agents.huggingface_agent': huggingface_agent,
        # src.training.*
        'src.training': this_module,
        'src.training.trainer': trainer,
        'src.training.rollout_buffer': rollout_buffer,
        # src.evaluation.*
        'src.evaluation': this_module,
        'src.evaluation.evaluator': evaluator,
        'src.evaluation.metrics': metrics,
        # src.envs.*
        'src.envs': this_module,
        'src.envs.env_factory': env_factory,
        # src.utils.*
        'src.utils': this_module,
        'src.utils.config': config,
        'src.utils.logging': logger,
        'src.utils.checkpointing': this_module,
        'src.utils.video': video,
    }
    sys.modules.update(redirects)


def save_checkpoint(
    agent: PPOAgent,
    optimizer: torch.optim.Optimizer,
    global_step: int,
    path: Path,
    config=None,
) -> None:
    """
    Persist training state as a single ``.pt`` file.

    Args:
        agent: PPO agent whose weights are serialized.
        optimizer: Optimizer whose state is serialized.
        global_step: Current training step, stored for resumption.
        path: Destination checkpoint file; parent directories are created
            as needed.
        config: Optional training configuration to embed in the checkpoint.
    """
    # Make sure the destination directory exists before writing.
    path.parent.mkdir(parents=True, exist_ok=True)

    # Include a "config" entry only when one was supplied.
    extras = {} if config is None else {"config": config}
    torch.save(
        {
            "global_step": global_step,
            "agent_state_dict": agent.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            **extras,
        },
        path,
    )


def load_checkpoint(
    path: Path,
    agent: Optional[PPOAgent] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    map_location="cpu",
):
    """
    Load training checkpoint with support for legacy module paths.

    Handles checkpoints saved with the old src.* module structure by
    installing module redirects before loading.

    Args:
        path: Path to checkpoint file.
        agent: Optional agent to load state into.
        optimizer: Optional optimizer to load state into.
        map_location: Device to load checkpoint to (default: "cpu").

    Returns:
        Checkpoint dictionary with keys:
        - global_step
        - agent_state_dict
        - optimizer_state_dict
        - config (if saved)
    """
    # Install module redirects so unpickling can resolve legacy src.* classes.
    _install_module_redirects()

    # SECURITY: weights_only=False allows arbitrary pickled objects (needed
    # here because checkpoints may embed the config object), and unpickling
    # can execute code — only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=map_location, weights_only=False)

    if agent is not None:
        agent.load_state_dict(checkpoint["agent_state_dict"])

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    return checkpoint

