"""
Main orchestration module for loss-landscape animations.

The high-level steps to produce an animation are:

1. Load data via a PyTorch Lightning `LightningDataModule`.
2. Create a PyTorch Lightning model (`LightningModule`).
3. Record flattened parameters during training.
4. Use PCA's top 2 principal components to project parameters to 2D.
5. Evaluate a 2D slice of the loss landscape on that PCA plane.
6. Animate the optimization trajectory on top of the loss slice (currently a 3D surface animation; see the example below).
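
Example
-------
A minimal end-to-end run with the default spirals dataset and MLP
(a sketch; all arguments shown are parameters of ``loss_landscape_anim``)::

    from main import loss_landscape_anim

    loss_landscape_anim(num_epochs=10, seed=42, output_filename="sample.gif")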
"""
import argparse
import pathlib
import random

import numpy as np
import pytorch_lightning as pl
import torch

from plot import (
    animate_loss_surface_3d,
    animate_loss_surface_multi_3d,
    sample_frames,
)
from datamodule import MNISTDataModule, SpiralsDataModule
from loss_landscape import DimensionalityReduction, LossGrid
from model import MLP, LeNet


def set_global_seed(seed: int | None) -> None:
    """
    Set random seeds for Python, NumPy, PyTorch, and PyTorch Lightning.

    Parameters
    ----------
    seed : int or None
        Integer seed. If ``None``, this is a no-op.
    """
    if seed is None:
        return
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Let Lightning handle any additional seeding (e.g., dataloaders)
    pl.seed_everything(seed, workers=True)


def loss_landscape_anim(
    num_epochs: int,
    datamodule: pl.LightningDataModule | None = None,
    model: pl.LightningModule | None = None,
    optimizer: str = "adam",
    model_dirpath: str = "checkpoints/",
    model_filename: str = "model.pt",
    gpus: int = 0,
    load_model: bool = False,
    output_to_file: bool = True,
    output_filename: str = "sample.gif",
    giffps: int = 15,
    sampling: bool = False,
    num_frames: int = 300,
    seed: int | None = None,
    return_data: bool = False,
    view: str = "3d",
    elev: int = 50,
    azim: int = -40,
):
    """
    Create an optimization animation in the loss landscape.

    Parameters
    ----------
    num_epochs : int
        Number of epochs to train.
    datamodule : LightningDataModule or None, optional
        PyTorch Lightning data module. If None, defaults to
        :class:`SpiralsDataModule`.
    model : LightningModule or None, optional
        PyTorch Lightning model. If None, defaults to an
        :class:`MLP` with 1 hidden layer and 50 neurons.
    optimizer : str, optional
        Name of the optimizer to use ("adam", "sgd",
        "sgd+momentum"). Default is "adam".
    model_dirpath : str, optional
        Directory to save the model. Default is "checkpoints/".
    model_filename : str, optional
        Model filename. Default is "model.pt".
    gpus : int, optional
        Number of GPUs to use. If 0, run on CPU. Default is 0.
    load_model : bool, optional
        If True, load from a previously trained model instead of
        training. Default is False.
    output_to_file : bool, optional
        If True, write the GIF to file. Default is True.
    output_filename : str, optional
        Output GIF filename. Default is "sample.gif".
    giffps : int, optional
        Frames per second for the GIF. Default is 15.
    sampling : bool, optional
        Whether to uniformly sample from training steps when there are
        too many steps. Default is False.
    num_frames : int, optional
        Maximum number of frames in the animation (after sampling).
        Default is 300.
    seed : int or None, optional
        Integer seed for reproducible experiments. If None, no
        seeding is applied.
    return_data : bool, optional
        If True, return (optim_path, loss_path, acc_path) for inspection.
    view : {"3d"}, optional
        Currently only "3d" is supported (3D surface animation).
    elev : int, optional
        Elevation angle for 3D view (only used when view="3d").
    azim : int, optional
        Azimuth angle for 3D view (only used when view="3d").

    Returns
    -------
    tuple[list[torch.Tensor], list[float], list[float]] or None
        If return_data is True, returns
        (optim_path, loss_path, acc_path) where each element is a
        list over training steps. Otherwise returns None.
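
    Examples
    --------
    Train briefly on the default spirals data and inspect the recorded
    path (a sketch using only the parameters documented above)::

        optim_path, loss_path, acc_path = loss_landscape_anim(
            num_epochs=10, return_data=True, seed=42
        )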
    """
    set_global_seed(seed)

    if view != "3d":
        raise ValueError("Only '3d' view is currently supported.")

    if datamodule is None:
        print("Data module not provided, using sample data: spirals dataset")
        datamodule = SpiralsDataModule()

    if model is None and not load_model:
        print(
            "Model not provided, using default model: "
            "MLP with 1 hidden layer of 50 neurons"
        )
        model = MLP(
            input_dim=datamodule.input_dim,
            num_classes=datamodule.num_classes,
            learning_rate=5e-3,
            optimizer=optimizer,
            gpus=gpus,
        )

    model_dir = pathlib.Path(model_dirpath)
    if not model_dir.is_dir():
        print(f"Model directory {model_dir.absolute()} does not exist, creating now.")
        model_dir.mkdir(parents=True, exist_ok=True)
    file_path = model_dir / model_filename

    if gpus > 0:
        print("======== Using GPU for training ========")

    if not load_model:
        model.gpus = gpus
        train_loader = datamodule.train_dataloader()
        trainer = pl.Trainer(
            enable_progress_bar=True, max_epochs=num_epochs, gradient_clip_val=1.0
        )
        print(f"Training for {num_epochs} epochs...")
        trainer.fit(model, train_loader)
        torch.save(model, str(file_path))
        print(f"Model saved at {file_path.absolute()}.")
    else:
        print(f"Loading model from {file_path.absolute()}")

    if not file_path.is_file():
        raise FileNotFoundError(f"Model file not found: {file_path.absolute()}")

    # Load the model on CPU to avoid exhausting MPS/GPU memory during analysis.
    # The subsequent loss landscape computation does not require GPU acceleration.
    model = torch.load(str(file_path), map_location="cpu")
    sampled_optim_path = sample_frames(model.optim_path, max_frames=num_frames)
    optim_path, loss_path, accu_path = zip(
        *[
            (path["flattened_params"], path["loss"], path["accuracy"])
            for path in sampled_optim_path
        ]
    )

    print(f"\n# sampled steps in optimization path: {len(optim_path)}")

    print("Dimensionality reduction via PCA")
    dim_reduction = DimensionalityReduction(params_path=optim_path, seed=seed)
    reduced_dict = dim_reduction.reduce()
    path_2d = reduced_dict["path_2d"]
    directions = reduced_dict["reduced_dirs"]
    pcvariances = reduced_dict.get("pcvariances")

    loss_grid = LossGrid(
        optim_path=optim_path,
        model=model,
        data=datamodule.dataset.tensors,
        path_2d=path_2d,
        directions=directions,
    )

    # 3D animation of the loss surface with trajectory.
    # Use log-loss surface (for dynamic range) but display raw loss values in text.
    animate_loss_surface_3d(
        param_steps=path_2d,
        loss_steps=loss_path,
        loss_grid=loss_grid.loss_values_log_2d,
        coords=loss_grid.coords,
        pcvariances=pcvariances,
        giffps=giffps,
        sampling=sampling,
        max_frames=num_frames,
        elev=elev,
        azim=azim,
        output_to_file=output_to_file,
        filename=output_filename,
    )

    if return_data:
        return list(optim_path), list(loss_path), list(accu_path)


def _build_default_spirals_datamodule(batch_size: int = 64) -> SpiralsDataModule:
    """
    Construct a default SpiralsDataModule.

    Parameters
    ----------
    batch_size : int, optional
        Batch size for the spirals dataset. Default is 64.

    Returns
    -------
    SpiralsDataModule
        A configured spirals data module.
    """
    return SpiralsDataModule(batch_size=batch_size)


def _parse_args() -> argparse.Namespace:
    """
    Parse command line arguments for the minimal CLI.

    Returns
    -------
    argparse.Namespace
        Parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(
        description="Loss landscape animations for simple models."
    )
    parser.add_argument(
        "--mode",
        choices=["single", "multi"],
        default="single",
        help="Experiment mode: 'single' optimizer or 'multi' optimizer comparison.",
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        default="adam",
        help="Optimizer name for single mode (adam, sgd, sgd+momentum).",
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=50,
        help="Number of training epochs.",
    )
    parser.add_argument(
        "--view",
        choices=["3d"],
        default="3d",
        help="Type of visualization (currently only the 3d surface is supported).",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility.",
    )
    parser.add_argument(
        "--output-filename",
        type=str,
        default=None,
        help="Output GIF filename (defaults depend on mode).",
    )
    return parser.parse_args()


def _run_from_cli() -> None:
    """
    Run the script from the command line using parsed arguments.

    This is the entry point used when executing ``python main.py``.

    Returns
    -------
    None
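
    Examples
    --------
    Typical invocations (all flags are defined in :func:`_parse_args`)::

        python main.py --mode single --optimizer sgd --num-epochs 50
        python main.py --mode multi --seed 0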
    """
    args = _parse_args()

    if args.mode == "single":
        output_filename = (
            args.output_filename or "spirals_single.gif"
        )
        loss_landscape_anim(
            num_epochs=args.num_epochs,
            datamodule=_build_default_spirals_datamodule(batch_size=64),
            optimizer=args.optimizer,
            output_to_file=True,
            output_filename=output_filename,
            view=args.view,
            seed=args.seed,
        )
    else:
        output_filename = (
            args.output_filename or "optimizers_spirals_multi.gif"
        )
        datamodule = _build_default_spirals_datamodule(batch_size=64)

        def mlp_factory(opt_name: str):
            return MLP(
                input_dim=datamodule.input_dim,
                num_classes=datamodule.num_classes,
                learning_rate=5e-3,
                optimizer=opt_name,
                gpus=0,
            )

        compare_optimizers_loss_landscape(
            num_epochs=args.num_epochs,
            optimizer_names=("adam", "sgd", "sgd+momentum"),
            datamodule=datamodule,
            model_factory=mlp_factory,
            giffps=15,
            num_frames=300,
            seed=args.seed,
            output_to_file=True,
            output_filename=output_filename,
            view=args.view,
        )


def compare_optimizers_loss_landscape(
    num_epochs: int,
    optimizer_names: tuple[str, ...] = ("adam", "sgd", "sgd+momentum"),
    datamodule: pl.LightningDataModule | None = None,
    learning_rate: float = 5e-3,
    model_dirpath: str = "checkpoints/",
    gpus: int = 0,
    seed: int | None = None,
    giffps: int = 15,
    num_frames: int = 300,
    output_to_file: bool = True,
    output_filename: str = "optimizers_multi.gif",
    model_factory=None,
    view: str = "3d",
    elev: int = 50,
    azim: int = -40,
):
    """
    Compare multiple optimizers on the same loss landscape slice.

    Parameters
    ----------
    num_epochs : int
        Number of epochs to train each optimizer.
    optimizer_names : tuple of str, optional
        Tuple/list of optimizer names, e.g.
        ``("adam", "sgd", "sgd+momentum")``. The first optimizer is
        used as the reference for PCA.
    datamodule : LightningDataModule or None, optional
        PyTorch Lightning data module. If None, defaults to
        :class:`SpiralsDataModule`.
    learning_rate : float, optional
        Learning rate shared by all optimizers (used if
        model_factory is None). Default is 5e-3.
    model_dirpath : str, optional
        Directory to save models (not strictly needed for comparison,
        but kept for consistency). Default is "checkpoints/".
    gpus : int, optional
        Number of GPUs if available. Default is 0.
    seed : int or None, optional
        Seed for reproducible experiments. If None, no seeding is
        applied.
    giffps : int, optional
        Frames per second for the GIF. Default is 15.
    num_frames : int, optional
        Maximum number of frames to sample from each optimizer path.
        Default is 300.
    output_to_file : bool, optional
        Whether to write the GIF to file. Default is True.
    output_filename : str, optional
        Output filename for the multi-optimizer GIF. Default is
        "optimizers_multi.gif".
    model_factory : callable or None, optional
        Callable taking a single optimizer name and returning a new,
        untrained :class:`LightningModule` with that optimizer
        configured. If None, a default MLP is used.
    view : {"3d"}, optional
        Currently only "3d" is supported (3D surface animation).
    elev : int, optional
        Elevation angle for the 3D plot (only used if view="3d").
    azim : int, optional
        Azimuth angle for the 3D plot (only used if view="3d").

    Returns
    -------
    None
        The function is used for its side effect of creating an
        animation (GIF) on disk.
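
    Examples
    --------
    Compare the three default optimizers on the spirals dataset
    (a sketch; arguments are parameters documented above)::

        compare_optimizers_loss_landscape(
            num_epochs=25,
            optimizer_names=("adam", "sgd", "sgd+momentum"),
            seed=42,
        )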
    """
    set_global_seed(seed)

    if view != "3d":
        raise ValueError("Only '3d' view is currently supported.")

    if datamodule is None:
        print("Data module not provided, using sample data: spirals dataset")
        datamodule = SpiralsDataModule()

    # Create a base model and save its initial weights so each optimizer starts
    # from exactly the same point in parameter space.
    if model_factory is not None:
        base_model = model_factory(optimizer_names[0])
    else:
        base_model = MLP(
            input_dim=datamodule.input_dim,
            num_classes=datamodule.num_classes,
            learning_rate=learning_rate,
            optimizer=optimizer_names[0],
            gpus=gpus,
        )
    init_state_dict = base_model.state_dict()

    optimizer_paths = {}
    optimizer_loss_paths = {}
    ref_optimizer_name = optimizer_names[0]
    ref_model = None

    model_dir = pathlib.Path(model_dirpath)
    if not model_dir.is_dir():
        model_dir.mkdir(parents=True, exist_ok=True)

    for opt_name in optimizer_names:
        print(f"\n=== Training with optimizer: {opt_name} ===")
        if model_factory is not None:
            model = model_factory(opt_name)
        else:
            model = MLP(
                input_dim=datamodule.input_dim,
                num_classes=datamodule.num_classes,
                learning_rate=learning_rate,
                optimizer=opt_name,
                gpus=gpus,
            )
        model.load_state_dict(init_state_dict)

        if gpus > 0:
            print("======== Using GPU for training ========")

        model.gpus = gpus
        train_loader = datamodule.train_dataloader()
        trainer = pl.Trainer(
            enable_progress_bar=True, max_epochs=num_epochs, gradient_clip_val=1.0
        )
        print(f"Training for {num_epochs} epochs with {opt_name}...")
        trainer.fit(model, train_loader)

        sampled_optim_path = sample_frames(model.optim_path, max_frames=num_frames)
        optim_path, loss_path, _ = zip(
            *[
                (path["flattened_params"], path["loss"], path["accuracy"])
                for path in sampled_optim_path
            ]
        )

        optimizer_paths[opt_name] = list(optim_path)
        optimizer_loss_paths[opt_name] = list(loss_path)

        if opt_name == ref_optimizer_name:
            ref_model = model

    # Use the reference optimizer's path to define the 2D PCA plane.
    ref_optim_path = optimizer_paths[ref_optimizer_name]
    print(f"\n# sampled steps in optimization path ({ref_optimizer_name}): {len(ref_optim_path)}")
    print("Dimensionality reduction via PCA (reference optimizer)")
    dim_reduction = DimensionalityReduction(params_path=ref_optim_path, seed=seed)
    reduced_dict = dim_reduction.reduce()
    ref_path_2d = reduced_dict["path_2d"]
    directions = reduced_dict["reduced_dirs"]  # shape (2, D)
    pcvariances = reduced_dict.get("pcvariances")

    # Build a common loss grid centered on the reference optimizer's trajectory.
    loss_grid = LossGrid(
        optim_path=ref_optim_path,
        model=ref_model,
        data=datamodule.dataset.tensors,
        path_2d=ref_path_2d,
        directions=directions,
    )

    # Project each optimizer's path into the same 2D plane.
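    # Note: this assumes the reduced directions define the plane with no
    # offset; if DimensionalityReduction mean-centers parameters before
    # fitting PCA, the same mean would need to be subtracted here for these
    # projections to align with ref_path_2d.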
    optimizer_paths_2d = {}
    for opt_name, path in optimizer_paths.items():
        if opt_name == ref_optimizer_name:
            optimizer_paths_2d[opt_name] = ref_path_2d
        else:
            npvectors = [np.array(t.cpu()) for t in path]
            path_matrix = np.vstack(npvectors)  # (T, D)
            # directions: (2, D) -> transpose to (D, 2)
            path_2d = path_matrix.dot(directions.T)
            optimizer_paths_2d[opt_name] = path_2d

    # Animate all optimizers together in 3D.
    animate_loss_surface_multi_3d(
        optimizer_paths_2d=optimizer_paths_2d,
        optimizer_loss_paths=optimizer_loss_paths,
        optimizer_names=list(optimizer_names),
        loss_grid=loss_grid.loss_values_log_2d,
        coords=loss_grid.coords,
        pcvariances=pcvariances,
        giffps=giffps,
        sampling=False,
        max_frames=num_frames,
        elev=elev,
        azim=azim,
        figsize=(9, 6),
        output_to_file=output_to_file,
        filename=output_filename,
    )


if __name__ == "__main__":
    _run_from_cli()