"""
Main training script for PPO on Pusher-v5.

Usage:
    python train.py --exp-name my_experiment --total-timesteps 1000000
    
This is a thin orchestration layer that:
1. Creates configuration
2. Sets up environment and agent
3. Instantiates trainer
4. Runs training loop
"""

import argparse
import json
import random
from pathlib import Path
from datetime import datetime
import numpy as np
import torch

from config import PPOConfig
from ppo_agent import PPOAgent
from env_factory import make_vec_envs
from trainer import PPOTrainer


# ------------------------------------------------------------
# Main
# ------------------------------------------------------------

def main():
    parser = argparse.ArgumentParser(description="Train PPO agent on Pusher-v5")
    
    # Experiment
    parser.add_argument("--exp-name", type=str, default=None, help="Experiment name")
    parser.add_argument("--seed", type=int, default=None, help="Random seed")
    
    # Training
    parser.add_argument("--total-timesteps", type=int, default=None, help="Total training steps")
    parser.add_argument("--learning-rate", type=float, default=None, help="Learning rate")
    parser.add_argument("--num-envs", type=int, default=None, help="Number of parallel environments")
    parser.add_argument("--num-steps", type=int, default=None, help="Steps per rollout")
    parser.add_argument("--save-freq", type=int, default=None, help="Checkpoint frequency")
    
    # Early stopping
    parser.add_argument("--no-early-stopping", action="store_true", help="Disable early stopping")
    parser.add_argument("--patience", type=int, default=None, help="Early stopping patience")
    parser.add_argument("--eval-episodes", type=int, default=None, help="Episodes per validation")
    parser.add_argument("--min-improvement", type=float, default=None, help="Min improvement to reset patience")
    
    # Device
    parser.add_argument("--no-cuda", action="store_true", help="Disable CUDA")
    
    args = parser.parse_args()
    
    # Build config kwargs from CLI args, keeping only values the user actually
    # provided so that config.py defaults apply otherwise.
    config_kwargs = {
        key: value
        for key, value in vars(args).items()
        if value is not None and key not in ("no_cuda", "no_early_stopping")
    }
    # The store-true flags invert into their positive config counterparts.
    if args.no_cuda:
        config_kwargs["cuda"] = False
    if args.no_early_stopping:
        config_kwargs["early_stopping"] = False
    
    # Create config (uses config.py defaults unless overridden by CLI)
    config = PPOConfig(**config_kwargs)
    
    # Seeding
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
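    # Note: torch.manual_seed also seeds all CUDA devices, so no separate
    # torch.cuda.manual_seed_all call is needed here.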
    torch.backends.cudnn.deterministic = config.torch_deterministic
    
    # Device
    device = torch.device("cuda" if torch.cuda.is_available() and config.cuda else "cpu")
    
    # Run directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_name = f"{config.exp_name}_{timestamp}"
    run_dir = Path("runs") / run_name
    run_dir.mkdir(parents=True, exist_ok=True)
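    print(f"Run directory: {run_dir} | device: {device}")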
    
    # Save the resolved config for reproducibility; default=str keeps any
    # non-JSON-serializable fields (e.g. paths) from crashing the dump.
    with open(run_dir / "config.json", "w") as f:
        json.dump(config.__dict__, f, indent=2, default=str)
    
    # Create environment
    envs = make_vec_envs(config.env_id, config.num_envs, config.seed)
    
    # Create agent
    obs_dim = int(np.prod(envs.single_observation_space.shape))
    action_dim = int(np.prod(envs.single_action_space.shape))
    agent = PPOAgent(obs_dim, action_dim).to(device)
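    # PPOAgent supports .to(device), so it is assumed to be an nn.Module;
    # report its parameter count as a quick sanity check.
    n_params = sum(p.numel() for p in agent.parameters())
    print(f"Agent parameters: {n_params:,}")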
    
    # Create trainer
    trainer = PPOTrainer(
        config=config,
        agent=agent,
        envs=envs,
        run_dir=run_dir,
        device=device,
    )
    
    # Train
    trainer.train()
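
    # Release environment resources (assumes make_vec_envs returns a
    # Gymnasium vector env, which exposes close()).
    envs.close()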


if __name__ == "__main__":
    main()

