"""
Policy evaluator for running evaluation episodes.

Runs an agent in an environment for multiple episodes and collects
performance statistics (returns, lengths, etc.).
"""

import json
from pathlib import Path
from typing import List, Optional, Tuple

import gymnasium as gym

from metrics import compute_metrics


class PolicyEvaluator:
    """
    Evaluates an agent's policy on an environment.
    
    Runs evaluation episodes (typically with deterministic actions)
    and computes performance metrics.
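
    Example (assuming an ``agent`` exposing ``get_action(obs, deterministic=...)``):

        evaluator = PolicyEvaluator(gym.make("CartPole-v1"))
        metrics, returns, lengths = evaluator.evaluate(agent, num_episodes=10)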
    """
    
    def __init__(self, env: gym.Env, deterministic: bool = True):
        """
        Args:
            env: Gymnasium environment
            deterministic: If True, use deterministic policy during eval
        """
        self.env = env
        self.deterministic = deterministic
    
    def evaluate(self, agent, num_episodes: int) -> Tuple[dict, List[float], List[int]]:
        """
        Evaluate agent for multiple episodes.
        
        Args:
            agent: Agent to evaluate
            num_episodes: Number of evaluation episodes
            
        Returns:
            Tuple of:
            - metrics: Dictionary of performance statistics
            - returns: List of episode returns
            - lengths: List of episode lengths
        """
        returns = []
        lengths = []
        
        for _ in range(num_episodes):
            obs, _ = self.env.reset()
            done = False
            episode_return = 0.0
            episode_length = 0
            
            while not done:
                action = agent.get_action(obs, deterministic=self.deterministic)
                obs, reward, terminated, truncated, _ = self.env.step(action)
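                # An episode ends on either environment termination (e.g. goal
                # reached, failure state) or time-limit truncation.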
                done = terminated or truncated
                
                # Cast to a plain Python float so returns stay JSON-serializable
                # even if the environment yields numpy scalars.
                episode_return += float(reward)
                episode_length += 1
            
            returns.append(episode_return)
            lengths.append(episode_length)
        
        metrics = compute_metrics(returns, lengths)
        return metrics, returns, lengths
    
    def evaluate_and_save(self, agent, num_episodes: int, output_path: Optional[Path] = None) -> dict:
        """
        Evaluate agent and optionally save results to JSON.
        
        Args:
            agent: Agent to evaluate
            num_episodes: Number of evaluation episodes
            output_path: Optional path to save metrics JSON
            
        Returns:
            Dictionary of metrics
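
        The saved JSON has the shape:
            {"metrics": {...}, "returns": [...], "lengths": [...]}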
        """
        metrics, returns, lengths = self.evaluate(agent, num_episodes)
        
        if output_path is not None:
            output_path.parent.mkdir(parents=True, exist_ok=True)
            with open(output_path, "w") as f:
                json.dump({
                    "metrics": metrics,
                    "returns": returns,
                    "lengths": lengths,
                }, f, indent=2)
        
        return metrics
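

if __name__ == "__main__":
    # Minimal smoke-test sketch. The evaluator only assumes an agent with a
    # ``get_action(obs, deterministic=...)`` method; RandomAgent below is a
    # hypothetical stand-in for illustration, not part of this module's API.
    class RandomAgent:
        def __init__(self, action_space):
            self.action_space = action_space

        def get_action(self, obs, deterministic: bool = False):
            # A random policy ignores both the observation and the flag.
            return self.action_space.sample()

    env = gym.make("CartPole-v1")
    evaluator = PolicyEvaluator(env, deterministic=True)
    metrics, returns, lengths = evaluator.evaluate(
        RandomAgent(env.action_space), num_episodes=5
    )
    print(metrics)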

