"""
Agent wrapper for loading pretrained models from HuggingFace Hub.

Loads Stable-Baselines3 models (.zip format) for use as benchmarks.

Note: Minari datasets are NOT agents - they contain recorded trajectories,
not policies. Use the Minari data directly for analysis/comparison, not as agents.
"""

from pathlib import Path
import numpy as np
import gymnasium as gym

class HuggingFaceAgent:
    """
    Benchmark agent backed by a pretrained Stable-Baselines3 policy.

    The policy archive (.zip) is located in a HuggingFace Hub repository and
    loaded when the agent is constructed; afterwards the agent only answers
    action queries.
    """

    def __init__(self, model_id: str):
        """
        Create the agent and load its policy right away.

        Args:
            model_id: HuggingFace repository ID (e.g., "farama-minari/Pusher-v5-SAC-medium")
        """
        self.model_id = model_id
        # Placeholder; _load_model() replaces it with the loaded policy.
        self.model = None
        self._load_model()

    def _load_model(self) -> None:
        """
        Locate the model archive in the Hub repository and load it.

        The first repository file whose name ends in ".zip" is used, and it
        is always loaded with the SAC class.

        Raises:
            ValueError: If the repository lists no ".zip" file.
        """
        # Imported here so the heavy dependencies are only needed when a
        # model is actually loaded.
        from huggingface_sb3 import load_from_hub
        from huggingface_hub import list_repo_files
        from stable_baselines3 import SAC

        print(f"Loading model from HuggingFace: {self.model_id}")

        # Keep only the repository entries that look like SB3 archives.
        repo_entries = list_repo_files(self.model_id)
        archive_names = [entry for entry in repo_entries if entry.endswith(".zip")]

        if archive_names:
            model_filename = archive_names[0]
        else:
            raise ValueError(f"No .zip model file found in {self.model_id}")
        print(f"  Detected model file: {model_filename}")

        # Fetch the archive from the Hub, then build the policy from it.
        local_archive = load_from_hub(
            repo_id=self.model_id,
            filename=model_filename,
        )
        self.model = SAC.load(local_archive)

        print(f"Model loaded: {self.model_id}")

    def get_action(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray:
        """
        Query the loaded policy for an action.

        Args:
            obs: Current observation
            deterministic: Forwarded unchanged to the model's predict() call

        Returns:
            The first element of the pair returned by predict()
        """
        # predict() yields a pair; only the action part is of interest here.
        chosen_action, _unused_state = self.model.predict(
            obs, deterministic=deterministic
        )
        return chosen_action

    def save(self, path: Path) -> None:
        """No-op: only prints a warning, nothing is written to `path`."""
        print(f"Warning: HuggingFace models cannot be saved to {path}")

    def load(self, path: Path) -> None:
        """No-op: only prints a warning, `path` is never read."""
        print(f"Warning: HuggingFace models are loaded from Hub, ignoring {path}")

