import warnings
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import griddata

from matplotlib.animation import FuncAnimation

def animate_loss_surface_3d(
    param_steps,
    loss_steps,
    loss_grid,
    coords,
    pcvariances=None,
    giffps=15,
    sampling=False,
    max_frames=300,
    elev=30,
    azim=-60,
    figsize=(9, 6),
    output_to_file=True,
    filename="loss_surface_3d.gif",
):
    """
    Animate a 3D loss surface with optimization trajectory.

    The surface is ``loss_grid`` plotted over the PCA plane. The trajectory is
    drawn on top of it. Each step's z-value is linearly interpolated from
    ``loss_grid`` at the step's 2D position. Points outside the grid's convex
    hull fall back to nearest-neighbour. The on-plot "loss" readout is ``exp``
    of that interpolated value, i.e. ``loss_grid`` is treated as log-loss.

    Arguments:
        param_steps: 2D trajectory in PCA space, shape (T, 2).
        loss_steps: Loss values along the trajectory, length T. It is sampled
            together with ``param_steps`` but not used for plotting. Displayed
            values come from ``loss_grid``.
        loss_grid: 2D log-loss values over the PCA grid (same as used for
            contour). Must have the shape of ``np.meshgrid(coords_x, coords_y)``,
            i.e. ``(len(coords_y), len(coords_x))``.
        coords: (coords_x, coords_y) from LossGrid.
        pcvariances (optional): PCA variance ratios for axis labels.
        giffps (optional): Frames per second in the output.
        sampling (optional): Whether to sample from the steps.
        max_frames (optional): Max number of frames if sampling.
        elev, azim (optional): View angles for the 3D plot.
        figsize (optional): Figure size.
        output_to_file (optional): Whether to write to file. Otherwise the
            animation is shown interactively.
        filename (optional): Output filename, written relative to the current
            working directory.

    Raises:
        ValueError: If the (possibly sampled) trajectory is empty.
        RuntimeError: If any trajectory z-value is non-finite after
            interpolation.
    """
    if sampling:
        print(f"\nSampling {max_frames} from {len(param_steps)} input frames.")
        param_steps = sample_frames(param_steps, max_frames)
        loss_steps = sample_frames(loss_steps, max_frames)

    num_frames = len(param_steps)
    if num_frames == 0:
        raise ValueError("animate_loss_surface_3d needs at least one trajectory step.")
    print(f"\nTotal frames to process: {num_frames}, result frames per second: {giffps}")

    coords_x, coords_y = coords
    X, Y = np.meshgrid(coords_x, coords_y)
    Z = np.asarray(loss_grid)

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection="3d")

    # Static surface
    surface = ax.plot_surface(
        X,
        Y,
        Z,
        cmap="YlGnBu",
        linewidth=0,
        antialiased=True,
        alpha=0.6,
    )
    fig.colorbar(surface, ax=ax, shrink=0.5, aspect=10, label="log loss")

    # Trajectory coordinates in the PCA plane.
    xs_all = [W[0] for W in param_steps]
    ys_all = [W[1] for W in param_steps]

    # Interpolate z from the grid so the trajectory lies on the surface.
    # X, Y and Z are raveled in the same order, so each grid value pairs with
    # its (x, y) location.
    grid_points = np.column_stack([X.ravel(), Y.ravel()])
    grid_values = Z.ravel()
    trajectory_points = np.column_stack([xs_all, ys_all])

    zs_all = griddata(
        grid_points, grid_values, trajectory_points,
        method="linear", fill_value=np.nan
    )

    # Linear interpolation yields NaN outside the convex hull of the grid.
    # Fill those points with the nearest grid value instead.
    nan_mask = np.isnan(zs_all)
    if nan_mask.any():
        zs_all[nan_mask] = griddata(
            grid_points, grid_values, trajectory_points[nan_mask],
            method="nearest"
        )

    # Guard before transforming: all trajectory losses must be finite. NaN can
    # still appear here if the grid itself contains non-finite values.
    if not np.isfinite(zs_all).all():
        bad_idxs = np.flatnonzero(~np.isfinite(zs_all)).tolist()
        raise RuntimeError(
            "Non-finite loss values detected in animate_loss_surface_3d at steps "
            f"{bad_idxs[:10]} (showing up to 10). This usually means trajectory "
            "points are outside the grid bounds or interpolation failed."
        )

    # Raw loss for the text readout. The grid is log-loss, so exponentiate it.
    # The floor at log(eps) keeps extremely negative log values from showing
    # as exactly 0.
    eps = 1e-12
    zs_all_raw = np.exp(np.maximum(zs_all, np.log(eps)))

    # The path starts at the first point and grows as frames advance.
    (pathline,) = ax.plot(
        xs_all[:1],
        ys_all[:1],
        zs_all[:1],
        color="red",
        linewidth=1,
        label="trajectory",
        zorder=10,
    )
    # Moving marker to highlight current position without adding any extra
    # connecting lines (the line is entirely handled by `pathline`).
    (point,) = ax.plot(
        [xs_all[0]],
        [ys_all[0]],
        [zs_all[0]],
        "ko",
        markersize=4,
        zorder=11,
    )

    xlabel_text = "Principal Component 0"
    ylabel_text = "Principal Component 1"
    if pcvariances is not None:
        xlabel_text = f"principal component 0, {pcvariances[0]:.1%}"
        ylabel_text = f"principal component 1, {pcvariances[1]:.1%}"

    ax.set_xlabel(xlabel_text)
    ax.set_ylabel(ylabel_text)
    # z-values are taken from the log-loss grid (see colorbar label).
    ax.set_zlabel("log loss")
    ax.view_init(elev=elev, azim=azim)
    # Move camera closer so the surface occupies more of the frame.
    # NOTE(review): Axes3D.dist is deprecated in newer Matplotlib releases in
    # favour of set_box_aspect(..., zoom=...). Confirm it still has an effect
    # on the installed version.
    ax.dist = 5
    ax.legend(loc="upper right")

    step_text = ax.text2D(
        0.05, 0.9, "", fontsize=10, ha="left", va="center", transform=ax.transAxes
    )
    value_text = ax.text2D(
        0.05, 0.82, "", fontsize=10, ha="left", va="center", transform=ax.transAxes
    )

    def animate(i):
        # Slice instead of appending to shared lists. FuncAnimation re-draws
        # frame 0 for its initial draw and again when saving. Slicing keeps
        # each frame idempotent, so the path never gains duplicate points.
        pathline.set_data(xs_all[: i + 1], ys_all[: i + 1])
        pathline.set_3d_properties(zs_all[: i + 1])

        # Update current point (in log-loss space for z)
        point.set_data([xs_all[i]], [ys_all[i]])
        point.set_3d_properties([zs_all[i]])

        step_text.set_text(f"step: {i}")
        # Show raw loss to the user
        value_text.set_text(f"loss: {zs_all_raw[i]: .3f}")

        return pathline, point, step_text, value_text

    # Keep a module-level reference. Presumably this prevents the animation
    # from being garbage-collected before it is shown or saved.
    global anim
    anim = FuncAnimation(
        fig,
        animate,
        frames=num_frames,
        interval=100,
        blit=False,
        repeat=False,
    )

    if output_to_file:
        print(f"Writing {filename}.")
        anim.save(
            f"./{filename}",
            writer="imagemagick",
            fps=giffps,
            progress_callback=_animate_progress,
        )
        print(f"\n{filename} created successfully.")
    else:
        plt.ioff()
        plt.show()


def animate_loss_surface_multi_3d(
    optimizer_paths_2d,
    optimizer_loss_paths,
    optimizer_names,
    loss_grid,
    coords,
    pcvariances=None,
    giffps=15,
    sampling=False,
    max_frames=300,
    elev=30,
    azim=-60,
    figsize=(9, 6),
    output_to_file=True,
    filename="optimizers_multi_3d.gif",
):
    """
    Animate a 3D loss surface with multiple optimizers' trajectories.

    Each optimizer's path is drawn on the surface. Its z-values are linearly
    interpolated from ``loss_grid``, with nearest-neighbour fallback outside
    the grid's convex hull. The on-plot loss readout is ``exp`` of the
    interpolated value. The animation runs for as many frames as the longest
    path. Shorter paths stay frozen at their final state once exhausted, and
    their readout line disappears from the text.

    Arguments:
        optimizer_paths_2d: dict mapping optimizer name -> 2D path (array of shape (T, 2)).
            Not modified by this function.
        optimizer_loss_paths: dict mapping optimizer name -> list of loss values (length T).
            Currently unused. Displayed losses come from ``loss_grid``.
        optimizer_names: list of optimizer names (defines colors / legend order).
        loss_grid: 2D loss values over the PCA grid (typically log-loss), with the
            shape of ``np.meshgrid(coords_x, coords_y)``.
        coords: (coords_x, coords_y) from LossGrid.
        pcvariances: PCA variance ratios for axis labels.
        giffps: Frames per second for the GIF.
        sampling: Whether to sample from the steps (applied per optimizer).
        max_frames: Max frames if sampling.
        elev, azim: View angles for the 3D plot.
        figsize: Figure size.
        output_to_file: Whether to write GIF to disk (otherwise show interactively).
        filename: Output filename, written relative to the current working directory.

    Raises:
        ValueError: If ``optimizer_names`` is empty or any optimizer's path is empty.
        RuntimeError: If any trajectory z-value is non-finite after interpolation.
    """
    if not optimizer_names:
        raise ValueError("animate_loss_surface_multi_3d needs at least one optimizer.")

    # Build a local view of the paths so sampling never mutates the caller's dict.
    paths_2d = {name: optimizer_paths_2d[name] for name in optimizer_names}
    if sampling:
        paths_2d = {
            name: sample_frames(path, max_frames) for name, path in paths_2d.items()
        }

    empty = [name for name, path in paths_2d.items() if len(path) == 0]
    if empty:
        raise ValueError(
            f"animate_loss_surface_multi_3d got empty trajectories for: {empty}"
        )

    # Determine maximum trajectory length
    max_len = max(len(path) for path in paths_2d.values())
    print(f"\nTotal frames to process: {max_len}, result frames per second: {giffps}")

    coords_x, coords_y = coords
    X, Y = np.meshgrid(coords_x, coords_y)
    Z = np.asarray(loss_grid)

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection="3d")

    # Static surface
    surface = ax.plot_surface(
        X,
        Y,
        Z,
        cmap="YlGnBu",
        linewidth=0,
        antialiased=True,
        alpha=0.6,
    )
    fig.colorbar(surface, ax=ax, shrink=0.5, aspect=10, label="log loss")

    # Grid points for interpolation, shared across all optimizers. X, Y and Z
    # are raveled in the same order so values pair with their locations.
    grid_points = np.column_stack([X.ravel(), Y.ravel()])
    grid_values = Z.ravel()

    # Colors per optimizer (cycled if there are more optimizers than colors).
    default_colors = ["r", "g", "y", "m", "c", "b"]
    # Floor for the exp() readout so extremely negative log values don't show as 0.
    eps = 1e-12

    # Per-optimizer data: name -> (xs_all, ys_all, zs_all, losses_raw).
    tracks = {}
    lines = {}
    points = {}

    for idx, opt_name in enumerate(optimizer_names):
        path_2d = paths_2d[opt_name]
        xs_all = [W[0] for W in path_2d]
        ys_all = [W[1] for W in path_2d]

        # Interpolate trajectory z-values from the loss grid
        trajectory_points = np.column_stack([xs_all, ys_all])
        zs_all = griddata(
            grid_points, grid_values, trajectory_points,
            method="linear", fill_value=np.nan
        )

        # Linear interpolation is NaN outside the convex hull; use nearest there.
        nan_mask = np.isnan(zs_all)
        if nan_mask.any():
            zs_all[nan_mask] = griddata(
                grid_points, grid_values, trajectory_points[nan_mask],
                method="nearest"
            )

        # Guard: all trajectory losses must be finite for this optimizer
        if not np.isfinite(zs_all).all():
            bad_idxs = np.flatnonzero(~np.isfinite(zs_all)).tolist()
            raise RuntimeError(
                "Non-finite loss values detected in animate_loss_surface_multi_3d "
                f"for optimizer '{opt_name}' at steps {bad_idxs[:10]} "
                "(showing up to 10). This usually means trajectory points are "
                "outside the grid bounds or interpolation failed."
            )

        # Raw loss values for display (exponentiate from log)
        losses_raw = np.exp(np.maximum(zs_all, np.log(eps)))
        tracks[opt_name] = (xs_all, ys_all, zs_all, losses_raw)

        color = default_colors[idx % len(default_colors)]
        (line,) = ax.plot(
            xs_all[:1],
            ys_all[:1],
            zs_all[:1],
            color=color,
            linewidth=1,
            label=opt_name,
            zorder=10,
        )
        (point,) = ax.plot(
            [xs_all[0]],
            [ys_all[0]],
            [zs_all[0]],
            marker="o",
            color=color,
            markersize=4,
            zorder=11,
        )

        lines[opt_name] = line
        points[opt_name] = point

    xlabel_text = "Principal Component 0"
    ylabel_text = "Principal Component 1"
    if pcvariances is not None:
        xlabel_text = f"principal component 0, {pcvariances[0]:.1%}"
        ylabel_text = f"principal component 1, {pcvariances[1]:.1%}"

    ax.set_xlabel(xlabel_text)
    ax.set_ylabel(ylabel_text)
    ax.set_zlabel("log loss")
    ax.view_init(elev=elev, azim=azim)
    # NOTE(review): Axes3D.dist is deprecated in newer Matplotlib releases in
    # favour of set_box_aspect(..., zoom=...). Confirm it still has an effect.
    ax.dist = 5
    ax.legend(loc="upper right")

    step_text = ax.text2D(
        0.05, 0.9, "", fontsize=10, ha="left", va="center", transform=ax.transAxes
    )
    value_text = ax.text2D(
        0.05, 0.82, "", fontsize=10, ha="left", va="center", transform=ax.transAxes
    )

    def animate(i):
        info_lines = [f"step: {i}"]
        for opt_name in optimizer_names:
            xs_all, ys_all, zs_all, losses_raw = tracks[opt_name]
            if i >= len(xs_all):
                # This optimizer has finished; leave its artists as they are.
                continue

            # Slice instead of appending so replaying a frame (initial draw,
            # save) is idempotent and never duplicates path points.
            line = lines[opt_name]
            line.set_data(xs_all[: i + 1], ys_all[: i + 1])
            line.set_3d_properties(zs_all[: i + 1])

            point = points[opt_name]
            point.set_data([xs_all[i]], [ys_all[i]])
            point.set_3d_properties([zs_all[i]])

            info_lines.append(f"{opt_name}: loss={losses_raw[i]:.3f}")

        step_text.set_text(info_lines[0])
        value_text.set_text("\n".join(info_lines[1:]))

        return list(lines.values()) + list(points.values()) + [step_text, value_text]

    # Keep a module-level reference. Presumably this prevents the animation
    # from being garbage-collected before it is shown or saved.
    global anim
    anim = FuncAnimation(
        fig,
        animate,
        frames=max_len,
        interval=100,
        blit=False,
        repeat=False,
    )

    if output_to_file:
        print(f"Writing {filename}.")
        anim.save(
            f"./{filename}",
            writer="imagemagick",
            fps=giffps,
            progress_callback=_animate_progress,
        )
        print(f"\n{filename} created successfully.")
    else:
        plt.ioff()
        plt.show()


def _animate_progress(current_frame, total_frames):
    print("\r" + f"Processing {current_frame+1}/{total_frames} frames...", end="")
    if current_frame + 1 == total_frames:
        print("\nConverting to gif, this may take a while...")


def sample_frames(steps, max_frames):
    """
    Sample at most ``max_frames`` frames from ``steps`` at a uniform stride.

    The stride is ``len(steps) // max_frames``. The kept indices are the
    ``max_frames`` largest multiples of that stride below ``len(steps)``,
    returned in chronological order. Indices are multiples of the stride
    counted from index 0. The final step is therefore included only when its
    index happens to be such a multiple.

    This is useful when we have many training steps but only want a
    smaller number of frames for animation.

    Arguments:
        steps: The frames to sample from (any indexable sequence with ``len``).
        max_frames: Maximum number of frames to sample; must be at least 1.

    Returns:
        A list of the sampled frames in their original order. If fewer than
        ``max_frames`` frames are available, every frame is returned and a
        warning is emitted. An empty ``steps`` yields an empty list.

    Raises:
        ValueError: If ``max_frames`` is less than 1.
    """
    if max_frames < 1:
        raise ValueError(f"max_frames must be at least 1, got {max_frames}.")
    steps_len = len(steps)
    if max_frames > steps_len:
        warnings.warn(
            f"Less than {max_frames} frames provided, producing {steps_len} frames."
        )
        max_frames = steps_len
    if steps_len == 0:
        # Nothing to sample; also avoids dividing by zero below.
        return []
    interval = steps_len // max_frames
    # Multiples of ``interval`` in [0, steps_len). Keep the last ``max_frames``
    # of them, which are the ones a backwards scan from the end would pick.
    # There are always at least ``max_frames`` of them because
    # interval <= steps_len / max_frames.
    indices = range(0, steps_len, interval)[-max_frames:]
    return [steps[i] for i in indices]
