import jax
import jax.numpy as jnp
import flax.linen as nn
from functools import partial
import optax
from einops import rearrange

# --- Hyper-parameters and shared state --------------------------------------
bsz = 26           # images per batch requested from the data loader
img_dim = 128      # side length (pixels) of the square training images
patch_dim = 16     # side length (pixels) of one square patch; must divide img_dim
img_channels = 3   # channel count of the example batch used to initialise the model
# NOTE(review): this key is consumed by model.init below and is also the key
# the training loop keeps splitting; consider deriving a dedicated init key.
rng = jax.random.PRNGKey(1)
flow_steps = 32    # number of integration steps taken by sample()


def patchify(x):
    """Cut images into non-overlapping square patches and flatten each one.

    Takes an array laid out as (b, c, H, W), with H and W multiples of
    ``patch_dim``, and returns (b, num_patches, c * patch_dim * patch_dim).
    """
    return rearrange(
        x, 'b c (h d1) (w d2) -> b (h w) (c d1 d2)', d1=patch_dim, d2=patch_dim
    )


def unpatchify(x):
    """Inverse of ``patchify`` for images of side ``img_dim``.

    Takes (b, num_patches, c * patch_dim * patch_dim) and returns
    (b, c, img_dim, img_dim).
    """
    patches_per_side = img_dim // patch_dim
    return rearrange(
        x,
        'b (h w) (c d1 d2) -> b c (h d1) (w d2)',
        d1=patch_dim,
        d2=patch_dim,
        h=patches_per_side,
        w=patches_per_side,
    )


# Dummy inputs: they fix the shapes used for parameter initialisation here and
# for the single-image noise/time tensors built in sample().
ex_batch = jnp.zeros((bsz, img_channels, img_dim, img_dim))
ex_t = jnp.zeros((bsz, 1, 1))

# Width of one flattened patch; evaluates to 768, the value originally
# hard-coded as `out`. Presumably `out` must equal this so the model output
# matches the patchified velocity target in train_step -- confirm against the
# FlowTransformer definition.
patch_features = img_channels * patch_dim * patch_dim
# NOTE(review): FlowTransformer is not imported anywhere in the visible
# source; it has to be brought into scope before this line can run.
model = FlowTransformer(heads=8, out=patch_features) # from big vision
grad_params = model.init(rng, patchify(ex_batch), t=ex_t)
opt = optax.adamw(learning_rate=1e-3)
opt_state = opt.init(grad_params)

@jax.jit
def train_step(params, opts, batch, rng):
    """Perform one gradient step on the velocity-regression objective.

    For every example a Gaussian noise sample and a uniform time ``t`` are
    drawn; the model sees the linear blend ``t * data + (1 - t) * noise`` and
    is regressed (mean squared error) onto ``data - noise``.

    Args:
        params: variables passed to ``model.apply``.
        opts: optimizer state belonging to the module-level ``opt``.
        batch: training data; its leading axis is the batch axis.
        rng: PRNG key for this step.

    Returns:
        A ``(params, opts, loss)`` tuple with the updated variables, the
        updated optimizer state and the scalar loss.
    """
    # Only the second key of the split is consumed by this step.
    _, step_key = jax.random.split(rng)

    def mse_loss(model_params):
        noise_key, time_key = jax.random.split(step_key)
        data = batch
        noise = jax.random.normal(noise_key, data.shape)
        target_velocity = data - noise
        # One time value per example, shaped (batch, 1, 1) so it broadcasts
        # against the remaining axes of the batch.
        t = jax.random.uniform(time_key, (data.shape[0], 1, 1))
        blended = t * data + (1 - t) * noise
        prediction = model.apply(model_params, blended, t=t)
        return jnp.mean(jnp.square(prediction - target_velocity))

    loss, grads = jax.value_and_grad(mse_loss)(params)
    updates, new_opts = opt.update(grads, opts, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opts, loss


import tensorflow as tf
import cv2
import numpy as np
import time

# Not jitted: the loop converts arrays to NumPy and writes frames through
# OpenCV, which are host-side effects that cannot run inside a jax.jit trace.
def sample(params, rng):
    """Generate one image and record the intermediate states as a video.

    Starts from Gaussian noise at ``t = 0`` and takes ``flow_steps`` equal
    steps of size ``1 / flow_steps`` along the model's predicted velocity.
    After every step the current image is clipped to [0, 1] and appended to an
    XVID ``.avi`` file named ``<timestamp>anim.avi`` in the working directory
    (1 frame per second).

    Args:
        params: variables passed to ``model.apply``.
        rng: PRNG key used to draw the initial noise.

    Returns:
        The final image, clipped to [0, 1], with shape
        ``(1, channels, img_dim, img_dim)`` in the model's own channel order
        (not converted for OpenCV).
    """
    video_name = str(time.time()).replace(".", "_") + 'anim.avi'
    frames = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'XVID'), 1, (img_dim,)*2)
    try:
        _, n_rng = jax.random.split(rng)
        x = patchify(jax.random.normal(n_rng, (1,) + ex_batch.shape[1:]))
        t = jnp.zeros((1,) + ex_t.shape[1:])
        for _ in range(flow_steps):
            # [0] selects the single batch element; broadcasting against x
            # restores the leading axis in the update below.
            pred = model.apply(params, x, t=t)[0]
            x = x + pred / flow_steps
            t = t + 1 / flow_steps

            img = rearrange(jnp.clip(unpatchify(x)[0], 0, 1), 'c h w -> h w c')
            frame_rgb = (np.array(img) * 255).astype(np.uint8)
            # The training data comes from the Keras loader, which yields RGB
            # by default, whereas OpenCV interprets channels as BGR: reverse
            # the channel axis so colours are not swapped in the video.
            frames.write(np.ascontiguousarray(frame_rgb[..., ::-1]))
    finally:
        # Always finalise the file, even if a step above raised.
        frames.release()
    return jnp.clip(unpatchify(x), 0, 1)


# Shuffled, unlabeled image batches read from 'imgs/' and resized to
# img_dim x img_dim. .repeat() makes the stream endless, so the loop below
# only stops when the process is interrupted.
dataset = tf.keras.preprocessing.image_dataset_from_directory(
    'imgs/', image_size=(img_dim,)*2, batch_size=bsz, shuffle=True, labels=None
).repeat()

for step, batch in enumerate(dataset):
    rng, train_rng = jax.random.split(rng)
    # Scale pixel values by 1/255 (presumably from [0, 255] to [0, 1]) and
    # move channels first, then cut into flattened patches for the model.
    batch = rearrange(jnp.array(batch.numpy()) / 255, 'b h w c -> b c h w')
    batch = patchify(batch)
    grad_params, opt_state, loss = train_step(grad_params, opt_state, batch, train_rng)
    print(step, loss)
    if step % 1000 == 0:
        # Every 1000 steps (including step 0): draw 10 samples, save each as a
        # PNG and preview it for up to 10 s (or until a key is pressed).
        for j in range(10):
            rng, s_rng = jax.random.split(rng)
            render = np.array(rearrange(sample(grad_params, s_rng), 'b c h w -> h w (b c)'))
            # The Keras loader yields RGB by default, so samples are RGB,
            # while cv2.imwrite / cv2.imshow expect BGR: reverse the channel
            # axis so red and blue are not swapped in the outputs.
            render_bgr = np.ascontiguousarray(render[..., ::-1])
            cv2.imwrite(f"{step}_" + str(time.time()).replace(".", "_") + ".png", (render_bgr * 255).astype(np.uint8))
            cv2.imshow("win", render_bgr)
            cv2.waitKey(10000)


