import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from scipy.integrate import solve_ivp
from scipy.interpolate import interp1d
from pycbc.waveform import get_td_waveform


def keplerODE(t, y, m1, m2, G, friction):
    """Return dy/dt for two damped, mutually gravitating point masses in a plane.

    ``y`` is packed as ``[x1, y1, vx1, vy1, x2, y2, vx2, vy2]``. Each body is
    pulled towards the other by Newtonian gravity and slowed by a linear drag
    ``-friction * v``; the drag removes orbital energy and forces an inspiral.

    Parameters
    ----------
    t : float
        Current time. Unused, but required by the ``solve_ivp`` call signature.
    y : ndarray
        Packed state vector (length 8, layout above).
    m1, m2 : float
        Masses of body 1 and body 2.
    G : float
        Gravitational constant, in the same unit system as masses and positions.
    friction : float
        Linear drag coefficient applied to both velocities.

    Returns
    -------
    ndarray
        Array shaped like ``y`` holding velocities and accelerations in the
        same packed layout.
    """
    pos1, vel1 = y[0:2], y[2:4]
    pos2, vel2 = y[4:6], y[6:8]

    sep = pos1 - pos2
    # The +1.0 (in the length unit of y) keeps the denominator away from zero
    # if the bodies ever coincide.
    denom = (np.linalg.norm(sep) + 1.0) ** 3

    # Gravity towards the partner plus linear drag on each body.
    acc1 = -G * m2 * sep / denom - friction * vel1
    acc2 = G * m1 * sep / denom - friction * vel2

    out = np.zeros_like(y)
    out[0:2], out[2:4] = vel1, acc1
    out[4:6], out[6:8] = vel2, acc2
    return out

def get_spiral_data(target_duration=0.2, steps=800):
    """
    Simulate the inspiral of a binary black hole (30 solar masses each).

    Newtonian gravity plus a linear drag term (see keplerODE) shrinks an
    initially circular orbit until the bodies come within 10 km of each other,
    at which point a terminal event stops the integration.

    The drag coefficient is ``2.5 / target_duration``, so the merger time
    scales with ``target_duration``. This is a heuristic, not an exact match:
    for slow, near-circular decay the separation falls roughly as
    ``exp(-2 * friction * t)``, which (700 km -> 10 km) estimates the
    collision at about ``0.85 * target_duration``.

    Parameters
    ----------
    target_duration : float
        Nominal merger time in seconds; must be positive.
    steps : int
        Number of evenly spaced output samples on ``[0, target_duration]``.
        Fewer are returned when the collision event stops the run early.

    Returns
    -------
    t : ndarray, shape (n,)
        Sample times in seconds.
    r1, r2 : ndarray, shape (n, 2)
        Positions of body 1 and body 2 in kilometres (centre-of-mass frame).

    Raises
    ------
    ValueError
        If ``target_duration`` is not positive.
    RuntimeError
        If the ODE solver fails before reaching the collision or the end time.
    """
    if target_duration <= 0:
        raise ValueError(f"target_duration must be positive, got {target_duration!r}")

    G = 6.67430e-11      # gravitational constant [m^3 kg^-1 s^-2]
    M_solar = 1.989e30   # solar mass [kg]

    m1 = 30 * M_solar
    m2 = 30 * M_solar

    R = 700000.0               # initial separation [m]
    merge_distance = 10000.0   # stop once the bodies are within 10 km [m]

    # Place both bodies on the x-axis and shift to the centre-of-mass frame.
    r1_loc = np.array([0.0, 0.0])
    r2_loc = np.array([R, 0.0])
    com = (m1 * r1_loc + m2 * r2_loc) / (m1 + m2)
    r1_loc -= com
    r2_loc -= com

    # Circular-orbit velocities: Omega times the position rotated by +90 degrees.
    Omega = np.sqrt(G * (m1 + m2) / R**3)
    v1_loc = Omega * np.array([-r1_loc[1], r1_loc[0]])
    v2_loc = Omega * np.array([-r2_loc[1], r2_loc[0]])

    y0 = np.concatenate([r1_loc, v1_loc, r2_loc, v2_loc])

    # Heuristic drag strength: stronger for shorter target durations.
    friction = 2.5 / target_duration

    t_eval = np.linspace(0, target_duration, steps)

    def collision_event(t, y):
        # Crosses zero when the separation drops below merge_distance.
        return np.linalg.norm(y[0:2] - y[4:6]) - merge_distance
    collision_event.terminal = True

    # The 20% margin on t_span only lets the solver keep looking for the
    # collision past target_duration; samples are still reported on t_eval.
    sol = solve_ivp(lambda t, y: keplerODE(t, y, m1, m2, G, friction),
                    (0, target_duration * 1.2), y0, t_eval=t_eval,
                    events=collision_event, rtol=1e-9)

    # A failed step (status -1) would otherwise silently yield a partial orbit.
    if not sol.success:
        raise RuntimeError(f"Orbit integration failed: {sol.message}")

    # Return data in kilometres.
    return sol.t, sol.y.T[:, 0:2] / 1000.0, sol.y.T[:, 4:6] / 1000.0

# Build the PyCBC reference waveform shown in the strain panel of the demo.
def get_waveform_data(duration_needed, post_merger=0.02):
    """
    Return the SEOBNRv4 plus-polarisation strain for a 30 + 30 solar-mass
    binary, cropped to the open window ``(-duration_needed, post_merger)`` s.

    The window treats t = 0 of PyCBC's ``sample_times`` as the merger, so it
    keeps the last ``duration_needed`` seconds before t = 0 plus a short
    post-merger buffer.

    Parameters
    ----------
    duration_needed : float
        Seconds of signal to keep before t = 0.
    post_merger : float, optional
        Seconds of signal to keep after t = 0 (default 0.02).

    Returns
    -------
    times, strain : ndarray
        Sample times (s) and h_plus values inside the window, as plain
        NumPy arrays.

    Raises
    ------
    ValueError
        If no waveform samples fall inside the requested window.
    """
    hp, _ = get_td_waveform(approximant="SEOBNRv4",
                            mass1=30, mass2=30,
                            delta_t=1.0/4096,
                            f_lower=20)

    # Convert the PyCBC Arrays to NumPy so masking and the caller's
    # interpolation operate on ordinary ndarrays.
    times = hp.sample_times.numpy()
    strain = hp.numpy()

    mask = (times > -duration_needed) & (times < post_merger)
    if not mask.any():
        # An empty slice would only surface later as an obscure interp1d error.
        raise ValueError(
            f"No waveform samples in window ({-duration_needed}, {post_merger}) s")
    return times[mask], strain[mask]


def _setup_orbit_axes(ax, limit=500):
    """Style the orbit panel (axes in km) and return its artists.

    Returns ``(line1, line2, body1, body2, flash, merger_text)``.
    """
    ax.set_facecolor('black')
    ax.set_xlim(-limit, limit)
    ax.set_ylim(-limit, limit)
    ax.set_aspect('equal')
    ax.set_title("Binary Simulation (Astrophysical)", fontsize=14)
    ax.set_xlabel("Distance (km)")
    ax.set_ylabel("Distance (km)")
    ax.grid()

    line1, = ax.plot([], [], color='cyan', lw=1.5, alpha=0.8)
    line2, = ax.plot([], [], color='orange', lw=1.5, alpha=0.8)
    body1, = ax.plot([], [], 'o', color='cyan', markersize=12, markeredgecolor='white')
    body2, = ax.plot([], [], 'o', color='orange', markersize=12, markeredgecolor='white')

    # Merger indicators, drawn above everything else.
    flash, = ax.plot([], [], '*', color='white', ms=40, markeredgecolor='#ffdd00',
                     markeredgewidth=2, zorder=20)
    merger_text = ax.text(0, 300, '', color='white', fontsize=16, fontweight='bold',
                          ha='center', va='center', zorder=21)
    return line1, line2, body1, body2, flash, merger_text


def _setup_wave_axes(ax, gw_time, strain):
    """Style the strain panel and return ``(wave_line, wave_head)``."""
    ax.set_title("Gravitational Wave Strain (PyCBC)", fontsize=14)
    ax.set_xlabel("Time to Merger (s)")
    ax.set_ylabel("Strain")
    ax.grid(True, alpha=0.3, ls='--')
    ax.set_xlim(gw_time[0], gw_time[-1] + 0.05)
    ax.set_ylim(np.min(strain) * 1.2, np.max(strain) * 1.2)

    wave_line, = ax.plot([], [], color='#ff3366', lw=2)
    wave_head, = ax.plot([], [], 'o', color='#ff3366', markersize=6)
    return wave_line, wave_head


def _save_animation(ani, gif_filename, mp4_filename, fps=30):
    """Save ``ani`` as a GIF (Pillow) and, if ffmpeg is available, as an MP4."""
    print("Saving Animation...")
    ani.save(gif_filename, writer=animation.PillowWriter(fps=fps), dpi=100)
    print(f"Success: Saved to '{gif_filename}'")

    # FFMpegWriter needs an external ffmpeg binary; without this check a
    # missing install raises only after the slow GIF render has finished.
    if not animation.FFMpegWriter.isAvailable():
        print(f"Skipping '{mp4_filename}': ffmpeg is not available.")
        return
    ani.save(mp4_filename, writer=animation.FFMpegWriter(fps=fps))
    print(f"Success! Saved to '{mp4_filename}'")


def run_unified_demo():
    """Animate the simulated inspiral next to the PyCBC strain and save it.

    The orbit from get_spiral_data and the waveform from get_waveform_data
    share one frame clock: the strain is resampled to one point per orbit
    frame. The merger flash appears once the resampled waveform reaches
    t = 0, and the final state is held for a few extra frames. Output is a
    GIF and, when ffmpeg is installed, an MP4.
    """
    orbit_t, r1s, r2s = get_spiral_data(target_duration=0.2, steps=600)
    sim_duration = orbit_t[-1]
    gw_t, gw_strain = get_waveform_data(duration_needed=sim_duration)

    # Resample the strain onto exactly one point per orbit frame. This
    # stretches the GW window (which runs slightly past t = 0) over the whole
    # orbit, so the two panels share frames rather than a physical time axis.
    mapped_gw_time = np.linspace(gw_t[0], gw_t[-1], len(orbit_t))
    interp_func = interp1d(gw_t, gw_strain, kind='cubic', fill_value="extrapolate")
    synced_strain = interp_func(mapped_gw_time)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    plt.subplots_adjust(wspace=0.3)
    line1, line2, body1, body2, flash, merger_text = _setup_orbit_axes(ax1, limit=500)
    wave_line, wave_head = _setup_wave_axes(ax2, mapped_gw_time, synced_strain)
    artists = (line1, line2, body1, body2, wave_line, wave_head, flash, merger_text)

    data_len = len(orbit_t)
    hold_frames = 30   # extra frames that freeze on the final state
    tail_skip = 10     # final orbit samples that are never drawn
    trail = 100        # trail length, in samples
    # Clamped at 0 so a very short run cannot produce a wrapping negative index.
    last_idx = max(0, data_len - tail_skip)
    # First frame whose resampled waveform time is >= 0, derived from the time
    # axis so the flash stays in sync if durations or step counts change.
    # searchsorted yields data_len if t = 0 is never reached, so the hold
    # frames always show the merged state.
    merger_frame = int(np.searchsorted(mapped_gw_time, 0.0))

    def init():
        for artist in (line1, line2, body1, body2, wave_line, wave_head, flash):
            artist.set_data([], [])
        merger_text.set_text('')
        return artists

    def update(i):
        idx = min(i, last_idx)

        # Orbit trails: up to `trail` samples leading up to the current one.
        start = max(0, idx - trail)
        line1.set_data(r1s[start:idx, 0], r1s[start:idx, 1])
        line2.set_data(r2s[start:idx, 0], r2s[start:idx, 1])

        if i >= merger_frame:
            body1.set_data([], [])
            body2.set_data([], [])
            flash.set_data([0], [0])
            merger_text.set_text('MERGED')
        else:
            body1.set_data([r1s[idx, 0]], [r1s[idx, 1]])
            body2.set_data([r2s[idx, 0]], [r2s[idx, 1]])
            flash.set_data([], [])
            merger_text.set_text('')

        wave_line.set_data(mapped_gw_time[:idx], synced_strain[:idx])
        wave_head.set_data([mapped_gw_time[idx]], [synced_strain[idx]])
        return artists

    ani = animation.FuncAnimation(fig, update, frames=data_len + hold_frames,
                                  init_func=init, interval=25, blit=True)

    # NOTE(review): "Chir" looks like a typo for "Chirp"; kept unchanged so
    # anything referencing the existing output file keeps working.
    _save_animation(ani, 'Generating_Chir.gif', "Generate_Chirp_Updated.mp4")
    plt.close(fig)
    

if "__main__" == __name__:
    # Render and save the orbit/strain animation only when run as a script,
    # never as a side effect of importing this module.
    run_unified_demo()