# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from matplotlib.animation import FFMpegWriter

# Constants
alpha = 2.3e-5  # thermal diffusivity of iron, m**2/s
dt = 0.1  # time step for the simulation, seconds
N = 200  # number of grid points along x and along y
L = 20  # side length of the square horizontal domain, meters
step = L / N  # horizontal grid spacing (0.1 m with the values above)
# Number of grid layers along z. NOTE(review): layers are plotted 1 unit apart,
# but fourier_heat uses a spacing of L / N along z -- confirm intended units.
height = 20
radius = 1  # rod radius, meters
# Rod axis at the middle of the domain; true division keeps it centred for odd L.
center = np.array([L / 2, L / 2])
bottom_temp = 700  # temperature of the open flame under the rod (presumably degrees C)

def createInitialConditions(N, L, height, r, center, bottom=None, ambient=25):
    """Build the starting temperature grid and the point lists used for plotting.

    The domain is an ``N x N x height`` grid whose horizontal spacing is
    ``L / N``.  A grid column belongs to the rod when its (x, y) position lies
    within ``r`` of ``center``.  Rod points on the bottom layer (``k == 0``)
    start at ``bottom``; every other point -- inside the rod or not -- starts
    at ``ambient``, both in the returned lists and in the returned grid.

    Parameters
    ----------
    N : int
        Number of grid points along x and along y.
    L : float
        Side length of the square domain; grid spacing is ``L / N``.
    height : int
        Number of grid layers along z.
    r : float
        Rod radius, in the same units as ``L``.
    center : array-like of two numbers
        (x, y) position of the rod axis.
    bottom : float, optional
        Starting temperature of the rod's bottom layer.  Defaults to the
        module-level ``bottom_temp``.
    ambient : float, optional
        Starting temperature of every other point (default 25).

    Returns
    -------
    positionArrays : list of 8 lists
        ``[xb, yb, zb, tb, xr, yr, zr, tr]`` -- x, y, layer index and
        temperature of the rod points ("b") and of all remaining points
        ("r"), in i, j, k loop order.
    initial : numpy.ndarray, shape (N, N, height)
        Temperature grid; ``initial[i, j, k]`` equals the list entry recorded
        for that point.
    """
    if bottom is None:
        bottom = bottom_temp
    step = L / N
    # Fill the whole grid with the ambient temperature so that points outside
    # the rod hold the same value in the grid as the one recorded in ``tr``.
    initial = np.full((N, N, height), ambient, dtype=float)

    # Rod ("b") and rest ("r") coordinate / temperature lists.
    xb, yb, zb, tb = [], [], [], []
    xr, yr, zr, tr = [], [], [], []
    for i in range(N):
        x = i * step
        for j in range(N):
            y = j * step
            # Rod membership depends only on (x, y): test it once per column.
            in_rod = np.linalg.norm(center - np.array([x, y])) <= r
            xs, ys, zs, ts = (xb, yb, zb, tb) if in_rod else (xr, yr, zr, tr)
            for k in range(height):
                t = bottom if (in_rod and k == 0) else ambient
                initial[i, j, k] = t
                xs.append(x)
                ys.append(y)
                zs.append(k)
                ts.append(t)
    positionArrays = [xb, yb, zb, tb, xr, yr, zr, tr]
    return positionArrays, initial

def classifyPoints(initial, step, center_xy=None, r=None):
    """Split grid points into rod / non-rod groups and read their temperatures.

    A column (i, j) belongs to the rod when ``(i * step, j * step)`` lies
    within the rod radius of the rod axis -- the same test used by
    ``createInitialConditions``.  The grid extents are taken from
    ``initial.shape`` rather than from module globals, so any 3-D grid can be
    classified.

    Parameters
    ----------
    initial : numpy.ndarray, shape (nx, ny, nz)
        Temperature grid to read from.
    step : float
        Horizontal grid spacing.
    center_xy : array-like of two numbers, optional
        (x, y) of the rod axis.  Defaults to the module-level ``center``.
    r : float, optional
        Rod radius.  Defaults to the module-level ``radius``.

    Returns
    -------
    list of 8 lists
        ``[xb, yb, zb, tb, xr, yr, zr, tr]`` -- x, y, layer index and
        temperature (read from ``initial``) of the rod points ("b") and of
        all remaining points ("r"), in i, j, k loop order.
    """
    # Module globals are only looked up when no override is supplied.
    axis_xy = center if center_xy is None else np.asarray(center_xy)
    rad = radius if r is None else r
    nx, ny, nz = initial.shape

    xb, yb, zb, tb = [], [], [], []
    xr, yr, zr, tr = [], [], [], []
    for i in range(nx):
        x = i * step
        for j in range(ny):
            y = j * step
            # Rod membership depends only on (x, y): test it once per column.
            in_rod = np.linalg.norm(axis_xy - np.array([x, y])) <= rad
            xs, ys, zs, ts = (xb, yb, zb, tb) if in_rod else (xr, yr, zr, tr)
            column = initial[i, j]
            for k in range(nz):
                xs.append(x)
                ys.append(y)
                zs.append(k)
                ts.append(column[k])
    positionArrays = [xb, yb, zb, tb, xr, yr, zr, tr]
    return positionArrays

def fourier_heat(t_old, N, L, height, dt, alpha, bottom=None):
    """Advance the temperature field by one time step with a spectral solver.

    The heat equation dT/dt = alpha * laplacian(T) is solved exactly in
    Fourier space for a step of ``dt``: every mode is multiplied by
    ``exp(-alpha * |k|**2 * dt)``.  A plain FFT treats the grid as periodic,
    which along z would make the heated bottom layer a neighbour of the top
    layer and heat the rod from both ends.  To avoid that, the field is
    mirrored along z before transforming (an even extension, equivalent to a
    cosine transform), which gives zero-flux ends in z; x and y remain
    periodic.  The bottom layer is clamped to ``bottom`` after the step.

    Parameters
    ----------
    t_old : numpy.ndarray, shape (N, N, height)
        Temperature field at the start of the step.
    N : int
        Number of grid points along x and along y.
    L : float
        Domain side length; the grid spacing ``L / N`` is used on every axis.
    height : int
        Number of grid layers along z.
    dt : float
        Time step.
    alpha : float
        Thermal diffusivity (units consistent with ``L`` and ``dt``).
    bottom : float, optional
        Temperature imposed on the z = 0 layer after the step.  Defaults to
        the module-level ``bottom_temp``.

    Returns
    -------
    numpy.ndarray, shape (N, N, height)
        Real-valued temperature field after the step.
    """
    if bottom is None:
        bottom = bottom_temp
    spacing = L / N

    # Even (mirror) extension along z: [T, reversed T] has no jump at the
    # periodic seam, so the two ends of the rod do not exchange heat.
    extended = np.concatenate([t_old, t_old[:, :, ::-1]], axis=2)
    nz = extended.shape[2]

    # Angular wavenumbers for each axis.
    kx = np.fft.fftfreq(N, d=spacing) * 2 * np.pi
    ky = np.fft.fftfreq(N, d=spacing) * 2 * np.pi
    kz = np.fft.fftfreq(nz, d=spacing) * 2 * np.pi
    # Broadcasting builds |k|^2 without materialising three full meshgrids.
    k_squared = kx[:, None, None] ** 2 + ky[None, :, None] ** 2 + kz[None, None, :] ** 2

    # Exact decay of each Fourier mode over one time step.
    T_freq = np.fft.fftn(extended)
    T_freq *= np.exp(-alpha * k_squared * dt)

    # Back to real space, keep only the physical (un-mirrored) half; copy so
    # the result is contiguous and does not pin the doubled complex buffer.
    t_new = np.fft.ifftn(T_freq).real[:, :, :height].copy()

    # Continuous heating of the bottom (fixed temperature at z = 0).
    t_new[:, :, 0] = bottom
    return t_new

# Initialize the temperature grid and the rod / rest point lists.
positionArrays, tInit = createInitialConditions(N, L, height, radius, center)

# Set up animation. The axes are rebuilt every frame (fig.clear() discards
# them), so only the figure and the writer are created up front.
fig = plt.figure(dpi=200)
metadata = dict(title='4D Cylinder Heated', artist='Mia', comment='Final Project')
writer = FFMpegWriter(fps=15, metadata=metadata, bitrate=200000)

number_of_frames = 100
# Simulation steps per frame: len(second) * dt seconds of simulated time
# (5 s with dt = 0.1).
second = range(50)

# Fixed colour scale: without vmin/vmax the scatter rescales to each frame's
# own min/max, so the same colour would mean different temperatures over time.
color_min, color_max = 25, bottom_temp

with writer.saving(fig, "4Dheat.mp4", dpi=200):
    for frame in range(number_of_frames):
        fig.clear()
        ax = fig.add_subplot(projection='3d')
        ax.set_xlim(0, L)
        ax.set_ylim(0, L)
        ax.set_zlim(0, height + 5)

        # Update temperature using the Fourier method.
        for _ in second:
            tInit = fourier_heat(tInit, N, L, height, dt, alpha)
        positionArrays = classifyPoints(tInit, step)

        # Plot only the rod points, coloured by temperature.
        elapsed = (frame + 1) * len(second) * dt
        ax.set_title(f"t = {elapsed:.1f} s")
        img = ax.scatter(
            positionArrays[0], positionArrays[1], positionArrays[2],
            c=positionArrays[3], cmap='hot', alpha=0.8,
            vmin=color_min, vmax=color_max,
        )
        fig.colorbar(img)
        plt.pause(0.03)
        writer.grab_frame()