# borrowed from https://github.com/TencentARC/MotionCtrl/blob/main/app.py

import argparse
import gc
import os
import tempfile
import threading
import time

import cv2
import gradio as gr
import imageio
import numpy as np
import torch


# Directory for every temporary artifact this app writes (trajectory previews, generated
# videos). `delete_old_files` defaults to sweeping this same path.
tempfile_dir = "/tmp/Tora"
os.makedirs(tempfile_dir, exist_ok=True)
# Presumably lets the process continue when more than one OpenMP runtime gets loaded.
# NOTE(review): set after numpy/torch/cv2 are already imported — confirm it still takes effect.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
# Presumably the hosting Space identifier ("" when run locally).
# NOTE(review): not referenced in the visible part of this file.
SPACE_ID = os.environ.get("SPACE_ID", "")
# Download the Tora checkpoints into ./ckpts at import time. The command's exit status is
# ignored, so a failed download does not stop start-up here.
os.system("modelscope download --model=xiaoche/Tora --local_dir ./ckpts")

#### Description ####
# Static text and styling for the Gradio UI.
# NOTE(review): `title` and `description` are only referenced by commented-out lines in main();
# usage of `article` and `css` is not visible in this chunk — confirm before removing any of them.
title = r"""<h1 align="center">Tora: Trajectory-oriented Diffusion Transformer for Video Generation</h1>"""

description = r""""""
# Markdown footer with the BibTeX citation of the Tora paper.
article = r"""
---

📝 **Citation**
<br>
```bibtex
@misc{zhang2024toratrajectoryorienteddiffusiontransformer,
      title={Tora: Trajectory-oriented Diffusion Transformer for Video Generation},
      author={Zhenghao Zhang and Junchao Liao and Menghao Li and Zuozhuo Dai and Bingxue Qiu and Siyu Zhu and Long Qin and Weizhi Wang},
      year={2024},
      eprint={2407.21705},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2407.21705},
}
```
"""
# Custom CSS for the Blocks layout.
# NOTE(review): the `svelte-*` selectors look like build-generated class names; they may stop
# matching after a Gradio upgrade.
css = """
.gradio-container {width: 85% !important}
.gr-monochrome-group {border-radius: 5px !important; border: revert-layer !important; border-width: 2px !important; color: black !important;}
span.svelte-s1r2yt {font-size: 17px !important; font-weight: bold !important; color: #d30f2f !important;}
button {border-radius: 8px !important;}
.add_button {background-color: #4CAF50 !important;}
.remove_button {background-color: #f44336 !important;}
.clear_button {background-color: gray !important;}
.mask_button_group {gap: 10px !important;}
.video {height: 300px !important;}
.image {height: 300px !important;}
.video .wrap.svelte-lcpz3o {display: flex !important; align-items: center !important; justify-content: center !important;}
.video .wrap.svelte-lcpz3o > :first-child {height: 100% !important;}
.margin_center {width: 50% !important; margin: auto !important;}
.jc_center {justify-content: center !important;}
"""

# Module-level trajectory buffers.
# NOTE(review): the UI callbacks below take `traj_list` as a parameter (shadowing this global)
# and model_run_v2 assigns its own local `traj_list_range_256`, so neither global appears to be
# read in the visible code — confirm before removing.
traj_list = []
traj_list_range_256 = []

# Drawing-canvas size in pixels; main() overwrites these defaults with 720 x 480.
canvas_width, canvas_height = 256, 256


# Preset trajectories offered in the UI, keyed by display name. Each preset is a list of [x, y]
# points in a 256 x 256 coordinate space (x indexes columns, y indexes rows, as in get_flow).
# Coordinates handed to process_traj are interpreted in this 0-256 range and rescaled to the
# video size there, so they must not exceed 256.
# main() rescales these presets to canvas pixel coordinates at start-up.
PROVIDED_TRAJS = {
    "circle": [
        [120, 194],
        [144, 193],
        [155, 189],
        [158, 170],
        [160, 153],
        [159, 123],
        [152, 113],
        [136, 100],
        [124, 100],
        [108, 100],
        [101, 106],
        [90, 110],
        [84, 129],
        [79, 146],
        [78, 165],
        [83, 182],
        [87, 189],
        [94, 192],
        [100, 194],
        [106, 194],
        [112, 194],
        [118, 195],
    ],
    "spiral": [
        [100, 127],
        [105, 117],
        [122, 117],
        [132, 129],
        [133, 158],
        [125, 181],
        [108, 189],
        [92, 185],
        [84, 179],
        [79, 163],
        [75, 142],
        [73, 118],
        [75, 82],
        [91, 63],
        [115, 52],
        [139, 46],
        [154, 55],
        [167, 93],
        [175, 112],
        [177, 137],
        [177, 158],
        [177, 171],
        [175, 188],
        [173, 204],
    ],
    "coaster": [
        [40, 208],
        [40, 148],
        [40, 100],
        [52, 58],
        [60, 57],
        [74, 68],
        [78, 90],
        [84, 123],
        [88, 148],
        [96, 168],
        [100, 181],
        [102, 188],
        [105, 192],
        [113, 118],
        [119, 80],
        [128, 68],
        [145, 109],
        [149, 155],
        [157, 175],
        [161, 184],
        [164, 184],
        [172, 166],
        [183, 107],
        [189, 84],
        [198, 76],
    ],
    "dance": [
        [81, 112],
        [86, 112],
        [92, 112],
        [100, 113],
        [102, 114],
        [97, 115],
        [92, 114],
        [86, 112],
        [81, 112],
        [80, 112],
        [84, 113],
        [89, 114],
        [95, 114],
        [101, 114],
        [102, 114],
        [103, 124],
        [105, 137],
        [109, 156],
        [114, 172],
        [119, 180],
        [124, 184],
        [131, 181],
        [140, 168],
        [146, 152],
        [150, 128],
        [151, 117],
        [152, 116],
        [156, 116],
        [163, 115],
        [169, 116],
        [175, 116],
        [173, 116],
        [167, 116],
        [162, 114],
        [157, 114],
        [152, 115],
        [156, 115],
        [163, 115],
        [168, 115],
        [174, 116],
        [175, 116],
        [168, 116],
        [162, 116],
        [152, 114],
        [149, 134],
        [145, 156],
        [139, 168],
        [130, 183],
        [118, 180],
        [112, 170],
        [107, 151],
        [102, 128],
        [103, 117],
        [96, 113],
        [88, 113],
        [83, 112],
        [80, 112],
    ],
    "infinity": [
        [60, 141],
        [71, 127],
        [92, 120],
        [112, 123],
        [130, 145],
        [145, 163],
        [167, 178],
        [189, 187],
        [206, 176],
        [213, 147],
        [208, 124],
        [190, 112],
        [176, 111],
        [158, 124],
        [145, 147],
        [125, 172],
        [104, 189],
        [72, 189],
        [59, 184],
        [55, 153],
        [57, 140],
        [75, 119],
        [112, 118],
        [129, 142],
        [149, 163],
        [168, 180],
        [194, 186],
        [206, 175],
        [211, 159],
        [212, 149],
        [212, 134],
        [206, 122],
        [180, 112],
        [163, 116],
        [149, 138],
        [128, 170],
        [108, 184],
        [86, 190],
        [63, 181],
        [57, 152],
        [57, 139],
    ],
    "pause": [
        [98, 186],
        [100, 188],
        [98, 186],
        [100, 188],
        [101, 187],
        [104, 187],
        [111, 184],
        [116, 176],
        [125, 162],
        [132, 140],
        [136, 119],
        [137, 104],
        [138, 96],
        [139, 94],
        [140, 94],
        [140, 96],
        [138, 98],
        [138, 96],
        [136, 94],
        [137, 92],
        [140, 92],
        [144, 92],
        [149, 92],
        [152, 92],
        [151, 92],
        [147, 92],
        [142, 92],
        [140, 92],
        [139, 95],
        [139, 105],
        [141, 122],
        [142, 143],
        [140, 167],
        [136, 184],
        [135, 188],
        [132, 195],
        [132, 192],
        [131, 192],
        [131, 192],
        [130, 192],
        [130, 195],
    ],
    "shake": [
        [103, 89],
        [104, 89],
        [106, 89],
        [107, 89],
        [108, 89],
        [109, 89],
        [110, 89],
        [111, 89],
        [112, 89],
        [113, 89],
        [114, 89],
        [115, 89],
        [116, 89],
        [117, 89],
        [118, 89],
        [119, 89],
        [120, 89],
        [122, 89],
        [123, 89],
        [124, 89],
        [125, 89],
        [126, 89],
        [127, 88],
        [128, 88],
        [129, 88],
        [130, 88],
        [131, 88],
        [133, 87],
        [136, 86],
        [137, 86],
        [138, 86],
        [139, 86],
        [140, 86],
        [141, 86],
        [142, 86],
        [143, 86],
        [144, 86],
        [145, 86],
        [146, 87],
        [147, 87],
        [148, 87],
        [149, 87],
        [148, 87],
        [146, 87],
        [145, 88],
        [144, 88],
        [142, 89],
        [141, 89],
        [140, 90],
        [140, 91],
        [138, 91],
        [137, 92],
        [136, 92],
        [136, 93],
        [135, 93],
        [134, 93],
        [133, 93],
        [132, 93],
        [131, 93],
        [130, 93],
        [129, 93],
        [128, 93],
        [127, 92],
        [125, 92],
        [124, 92],
        [123, 92],
        [122, 92],
        [121, 92],
        [120, 92],
        [119, 92],
        [118, 92],
        [117, 92],
        [116, 92],
        [115, 92],
        [113, 92],
        [112, 92],
        [111, 92],
        [110, 92],
        [109, 92],
        [108, 92],
        [108, 91],
        [108, 90],
        [109, 90],
        [110, 90],
        [111, 89],
        [112, 89],
        [113, 89],
        [114, 89],
        [115, 89],
        [115, 88],
        [116, 88],
        [117, 88],
        [118, 88],
        [118, 87],
        [119, 87],
        [120, 87],
        [121, 87],
        [122, 86],
        [123, 86],
        [124, 86],
        [125, 86],
        [126, 85],
        [127, 85],
        [128, 85],
        [129, 85],
        [130, 85],
        [131, 85],
        [132, 85],
        [133, 85],
        [134, 85],
        [135, 85],
        [136, 85],
        [137, 85],
        [138, 85],
        [139, 85],
        [140, 85],
        [141, 85],
        [142, 85],
        [143, 85],
        [143, 84],
        [144, 84],
        [145, 84],
        [146, 84],
        [147, 84],
        [148, 84],
        [149, 84],
        [148, 84],
        [147, 84],
        [145, 84],
        [144, 84],
        [143, 84],
        [142, 84],
        [141, 84],
        [140, 85],
        [139, 85],
        [138, 85],
        [137, 86],
        [136, 86],
        [136, 87],
        [135, 87],
        [134, 87],
        [133, 87],
        [132, 88],
        [131, 88],
        [130, 88],
        [129, 88],
        [129, 89],
        [128, 89],
        [127, 89],
        [126, 89],
        [125, 89],
        [124, 90],
        [123, 90],
        [122, 90],
        [121, 90],
        [120, 91],
        [119, 91],
        [118, 91],
        [117, 91],
        [116, 91],
        [115, 91],
        [114, 91],
        [113, 91],
        [112, 91],
        [111, 91],
        [110, 91],
        [109, 91],
        [109, 90],
        [108, 90],
        [110, 90],
        [111, 90],
        [113, 90],
        [114, 90],
        [115, 90],
        [116, 90],
        [118, 90],
        [120, 90],
        [121, 90],
        [122, 90],
        [123, 90],
        [124, 90],
        [126, 90],
        [127, 90],
        [128, 90],
        [129, 90],
        [130, 90],
        [131, 90],
        [132, 90],
        [133, 90],
        [134, 90],
        [135, 90],
        [136, 90],
        [137, 90],
        [138, 90],
        [139, 90],
        [140, 90],
        [141, 89],
        [142, 89],
        [143, 89],
        [144, 89],
        [145, 89],
        [146, 89],
        [147, 89],
        [147, 89],
        [147, 89],
    ],
    "wave": [
        [16, 152],
        [23, 138],
        [39, 122],
        [54, 115],
        [75, 118],
        [88, 130],
        [93, 150],
        [89, 176],
        [75, 184],
        [63, 177],
        [65, 152],
        [77, 135],
        [98, 121],
        [116, 120],
        [135, 127],
        [148, 136],
        [156, 145],
        [160, 165],
        [158, 176],
        [138, 187],
        [133, 185],
        [129, 148],
        [140, 133],
        [156, 120],
        [177, 118],
        [197, 118],
        [214, 119],
        [225, 118],
    ],
}


# Example prompts offered in the UI, keyed by a short display name (looked up by
# add_provided_prompt).
PROVIDED_PROMPTS = {
    "dandelion": "A dandelion puff sways gently in the wind, its seeds ready to take flight and spread into the world. The animation style highlights the delicate fibers of the puff, with soft, glowing light surrounding it. The background showcases a lush, green field, hinting at the beauty of nature. As the wind blows, the seeds dance and float away, creating an enchanting visual narrative. The gentle sounds of nature, alongside soft whispers of the breeze, enrich the overall ambiance. This serene scene invites viewers to embrace the moment of letting go, celebrating the cycle of life and new beginnings.",
    "golden retriever": "A golden retriever, sporting sleek black sunglasses, with its lengthy fur flowing in the breeze, sprints playfully across a rooftop terrace, recently refreshed by a light rain. The scene unfolds from a distance, the dog's energetic bounds growing larger as it approaches the camera, its tail wagging with unrestrained joy, while droplets of water glisten on the concrete behind it. The overcast sky provides a dramatic backdrop, emphasizing the vibrant golden coat of the canine as it dashes towards the viewer.",
    "rubber duck": "A cheerful rubber duck floats serenely in a bathtub filled with bubbles, the soft foam creating an inviting atmosphere. The bathroom setting is warm with bright tiles reflecting soft light. The camera captures playful angles, zeroing in on the duck's bright yellow color and big eyes. Sounds of water gently splashing and laughter fill the background, enhancing the joyous ambiance. This moment invites viewers to embrace nostalgia and childhood fun, evoking a sense of playfulness and relaxation.",
    "squirrel": "A squirrel gathering nuts.",
}


#############################
def pdf2(sigma_matrix, grid):
    """Evaluate an un-normalized bivariate Gaussian density on a grid.

    For every 2-vector ``v`` stored in the last axis of ``grid`` this computes
    ``exp(-0.5 * v^T Sigma^-1 v)``.

    Args:
        sigma_matrix (ndarray): covariance matrix with shape (2, 2).
        grid (ndarray): coordinates as produced by :func:`mesh_grid`,
            with shape (K, K, 2), where K is the kernel size.

    Returns:
        ndarray: un-normalized kernel with shape (K, K).
    """
    precision = np.linalg.inv(sigma_matrix)
    # Row-vector times precision matrix, then dot with the vector again -> squared Mahalanobis distance.
    projected = np.dot(grid, precision)
    mahalanobis_sq = np.sum(projected * grid, 2)
    return np.exp(-0.5 * mahalanobis_sq)


def mesh_grid(kernel_size):
    """Build a square coordinate grid centred on zero.

    Args:
        kernel_size (int): number of samples along each axis.

    Returns:
        tuple: ``(xy, xx, yy)``. ``xy`` has shape (kernel_size, kernel_size, 2) and holds the
        ``(xx, yy)`` pair of each cell in its last axis; ``xx`` and ``yy`` are the
        (kernel_size, kernel_size) coordinate matrices from :func:`numpy.meshgrid`.
    """
    # ``-kernel_size // 2`` floors the *negated* size, so odd sizes yield a symmetric axis
    # (-m .. m) and even sizes yield (-m + 1 .. m).
    lower = -kernel_size // 2 + 1.0
    upper = kernel_size // 2 + 1.0
    axis = np.arange(lower, upper)
    xx, yy = np.meshgrid(axis, axis)
    xy = np.stack((xx, yy), axis=-1)
    return xy, xx, yy


def sigma_matrix2(sig_x, sig_y, theta):
    """Return the covariance matrix of a rotated, axis-scaled 2-D Gaussian.

    Computes ``R @ (diag(sig_x**2, sig_y**2) @ R.T)`` where ``R`` is the 2-D rotation by ``theta``.

    Args:
        sig_x (float): standard deviation along the first (pre-rotation) axis.
        sig_y (float): standard deviation along the second (pre-rotation) axis.
        theta (float): rotation angle in radians.

    Returns:
        ndarray: the (2, 2) rotated covariance matrix.
    """
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
    scales = np.array([[sig_x**2, 0], [0, sig_y**2]])
    return np.dot(rotation, np.dot(scales, rotation.T))


def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
    """Generate a normalized bivariate Gaussian kernel (isotropic or anisotropic).

    In isotropic mode only ``sig_x`` is used; ``sig_y`` and ``theta`` are ignored.

    Args:
        kernel_size (int): side length K of the square kernel.
        sig_x (float): standard deviation along x (or in every direction when isotropic).
        sig_y (float): standard deviation along y (anisotropic mode only).
        theta (float): rotation in radians (anisotropic mode only).
        grid (ndarray, optional): precomputed (K, K, 2) grid from :func:`mesh_grid`;
            built from ``kernel_size`` when None.
        isotropic (bool): whether to use a circular Gaussian.

    Returns:
        ndarray: (K, K) kernel whose entries sum to 1.
    """
    coords = mesh_grid(kernel_size)[0] if grid is None else grid
    if isotropic:
        covariance = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        covariance = sigma_matrix2(sig_x, sig_y, theta)
    weights = pdf2(covariance, coords)
    return weights / np.sum(weights)


# Gaussian splat used by process_traj (via cv2.filter2D) to spread each sparse per-point flow
# vector over its neighbourhood: a 99 x 99 isotropic kernel with sigma 10.
size = 99
sigma = 10
blur_kernel = bivariate_Gaussian(size, sigma, sigma, 0, grid=None, isotropic=True)
# Divide by the centre weight so the peak is exactly 1.0 (instead of the kernel summing to 1);
# a flow value placed at a point therefore keeps its magnitude at that point after filtering,
# apart from overlap with other points' splats.
blur_kernel = blur_kernel / blur_kernel[size // 2, size // 2]
#############################


def get_flow(points, optical_flow, video_len):
    """Write sparse per-frame displacements of one trajectory into ``optical_flow`` in place.

    For each frame ``t`` in ``1 .. video_len - 1`` the displacement
    ``points[t] - points[t - 1]`` is stored at the pixel of ``points[t - 1]``:
    channel 0 gets dx and channel 1 gets dy. Frame 0 is never written.

    Args:
        points: at least ``video_len`` ``(x, y)`` integer pixel positions.
        optical_flow: array indexed as ``[frame, row (y), column (x), channel]``.
        video_len (int): number of frames to process.

    Returns:
        The same ``optical_flow`` object, after modification.
    """
    for frame in range(1, video_len):
        src = points[frame - 1]
        dst = points[frame]
        dx = dst[0] - src[0]
        dy = dst[1] - src[1]
        optical_flow[frame, src[1], src[0], 0] = dx
        optical_flow[frame, src[1], src[0], 1] = dy

    return optical_flow


def process_points(points, frames=49):
    """Resample a trajectory to exactly ``frames`` points.

    Args:
        points: sequence of ``(x, y)`` pairs.
        frames (int): number of points to produce.

    Returns:
        list: ``frames`` points, produced as follows.

        - Fewer than two input points: ``frames`` independent ``[128, 128]`` lists.
        - At least ``frames`` points: every ``len(points) // frames``-th point, truncated to
          ``frames - 1`` entries, followed by the final input point.
        - Otherwise: every input point, with evenly spaced ``[int(x), int(y)]`` points inserted
          between consecutive inputs (truncated toward zero). Any remainder of insertions goes to
          the earliest segments.

        Input points are passed through unchanged (the same objects).
    """
    if len(points) < 2:
        # One fresh list per frame: ``[[128, 128]] * frames`` would alias a single list, so
        # mutating one returned point would silently change all of them.
        return [[128, 128] for _ in range(frames)]

    if len(points) >= frames:
        skip = len(points) // frames
        # Stride-subsample, but always end on the true final point.
        return points[::skip][: frames - 1] + points[-1:]

    segments = len(points) - 1
    per_segment, extra = divmod(frames - len(points), segments)

    res = []
    for i in range(segments):
        # The first ``extra`` segments absorb one additional inserted point each.
        count = per_segment + (1 if i < extra else 0)
        x0, y0 = points[i]
        x1, y1 = points[i + 1]
        delta_x = x1 - x0
        delta_y = y1 - y0

        res.append(points[i])
        for j in range(count):
            frac = (j + 1) / (count + 1)
            res.append([int(x0 + frac * delta_x), int(y0 + frac * delta_y)])
    res.append(points[-1])
    return res


def read_points_from_list(traj_list, video_len=16, reverse=False):
    """Parse trajectory points and cap them at ``video_len`` entries.

    Each item is either an ``"x,y"`` string or an indexable pair; both coordinates are
    converted with ``int``. When there are more than ``video_len`` points, every
    ``len // video_len``-th point is kept, then the result is truncated to ``video_len``.

    Args:
        traj_list: iterable of ``"x,y"`` strings or ``(x, y)`` pairs.
        video_len (int): maximum number of points to return.
        reverse (bool): reverse the order before subsampling.

    Returns:
        list[tuple[int, int]]: the parsed points.
    """

    def _to_xy(item):
        if isinstance(item, str):
            raw_x, raw_y = item.strip().split(",")
        else:
            raw_x, raw_y = item[0], item[1]
        return int(raw_x), int(raw_y)

    coords = [_to_xy(item) for item in traj_list]
    if reverse:
        coords.reverse()

    if len(coords) > video_len:
        coords = coords[:: len(coords) // video_len]
    return coords[:video_len]


def read_points_from_file(file, video_len=16, reverse=False):
    """Read a trajectory from a text file containing one ``x,y`` integer pair per line.

    Blank (whitespace-only) lines are skipped, so a trailing empty line no longer raises.
    When there are more than ``video_len`` points, every ``len // video_len``-th point is kept,
    then the result is truncated to ``video_len``.

    Args:
        file: path of the text file.
        video_len (int): maximum number of points to return.
        reverse (bool): reverse the order before subsampling.

    Returns:
        list[tuple[int, int]]: the parsed points.

    Raises:
        ValueError: if a non-blank line is not two comma-separated integers.
    """
    points = []
    with open(file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            x, y = line.split(",")
            points.append((int(x), int(y)))
    if reverse:
        points = points[::-1]

    if len(points) > video_len:
        skip = len(points) // video_len
        points = points[::skip]
    return points[:video_len]


def process_traj(trajs_list, num_frames, video_size, device="cpu"):
    """Convert one or more trajectories into a dense, blurred optical-flow tensor.

    Args:
        trajs_list: a single trajectory (list of points) or a list of trajectories. Points are
            ``[x, y]`` pairs or ``"x,y"`` strings in a 256 x 256 coordinate space.
        num_frames (int): number of frames ``T`` of the output flow.
        video_size: ``(height, width)`` of the target video.
        device: torch device for the returned tensor.

    Returns:
        tuple: ``(optical_flow, processed_points)``.

        - ``optical_flow`` is a float32 tensor of shape ``(T, H, W, 2)`` holding per-pixel
          (dx, dy) displacements blurred with ``blur_kernel``; frame 0 is all zeros.
        - ``processed_points`` holds one list of ``num_frames`` ``[x, y]`` pixel points per
          trajectory.
    """
    # A flat list of points (its first point's first item is not a list/tuple) is one trajectory.
    if trajs_list and trajs_list[0] and (not isinstance(trajs_list[0][0], (list, tuple))):
        trajs_list = [trajs_list]

    xy_range = 256
    h, w = video_size[0], video_size[1]
    optical_flow = np.zeros((num_frames, h, w, 2), dtype=np.float32)
    processed_points = []
    for traj in trajs_list:
        points = read_points_from_list(traj, video_len=num_frames)
        points = process_points(points, num_frames)
        # Rescale from the 256 range to pixels and clamp into the frame. A coordinate of exactly
        # 256 would otherwise index one past the edge in get_flow (IndexError), and a negative
        # one would silently wrap around to the opposite border.
        points = [
            [
                min(max(int(w * x / xy_range), 0), w - 1),
                min(max(int(h * y / xy_range), 0), h - 1),
            ]
            for x, y in points
        ]
        optical_flow = get_flow(points, optical_flow, video_len=num_frames)
        processed_points.append(points)

    print(f"received {len(trajs_list)} trajectorie(s)")

    # Spread each sparse flow vector with the Gaussian splat; frame 0 carries no motion.
    for i in range(1, num_frames):
        optical_flow[i] = cv2.filter2D(optical_flow[i], -1, blur_kernel)

    optical_flow = torch.tensor(optical_flow).to(device)

    return optical_flow, processed_points


def fn_vis_realtime_traj(traj_list):
    """Render the current trajectory as a polyline on a blank white canvas.

    The points are first resampled with ``process_points`` (49 points by default). With fewer
    than two points this yields the fixed default point ``[128, 128]``, not the clicked one.

    Args:
        traj_list: list of ``[x, y]`` canvas points.

    Returns:
        ndarray: ``(canvas_height, canvas_width, 3)`` uint8 image with the path drawn in
        ``(255, 0, 0)`` at thickness 2.
    """
    path = process_points(traj_list)
    canvas = np.full((canvas_height, canvas_width, 3), 255, dtype=np.uint8)
    for start, end in zip(path, path[1:]):
        cv2.line(canvas, start, end, (255, 0, 0), 2)
    return canvas


def fn_vis_traj(traj_list):
    """Render the trajectory as an animated preview video.

    The trajectory is resampled with ``process_points``. Frame ``idx`` shows the full path in
    ``(255, 0, 0)`` with a ``(0, 255, 0)`` marker on point ``idx``.

    Args:
        traj_list: list of ``[x, y]`` canvas points.

    Returns:
        str: path of a new 8-fps .mp4 file written under ``tempfile_dir``.
    """
    points = process_points(traj_list)
    last = len(points) - 1
    imgs = []
    for idx in range(len(points)):
        bg_img = np.ones((canvas_height, canvas_width, 3), dtype=np.uint8) * 255
        for i in range(last):
            p = points[i]
            p1 = points[i + 1]
            cv2.line(bg_img, p, p1, (255, 0, 0), 2)

            # Drawn inside the line loop, so segments after point ``idx`` are painted over the marker.
            if i == idx:
                cv2.circle(bg_img, p, 2, (0, 255, 0), 20)

        # The final point has no outgoing segment, so it is marked separately.
        if idx == last:
            cv2.circle(bg_img, points[-1], 2, (0, 255, 0), 20)

        imgs.append(bg_img.astype(np.uint8))

    fps = 8
    with tempfile.NamedTemporaryFile(dir=tempfile_dir, suffix=".mp4", delete=False) as f:
        path = f.name
    writer = imageio.get_writer(path, format="mp4", mode="I", fps=fps)
    try:
        for img in imgs:
            writer.append_data(img)
    finally:
        # Release the writer even if a frame fails to encode; previously an exception here
        # leaked the open writer.
        writer.close()

    return path


def create_grid_image(width, height, grid_size):
    """Create a transparent RGBA overlay with light-grey grid lines.

    Vertical lines are drawn at ``x = 0, grid_size, 2 * grid_size, ...`` and horizontal lines at
    the same multiples of ``y``, each 1 px wide in ``(200, 200, 200)``.

    Args:
        width (int): image width in pixels.
        height (int): image height in pixels.
        grid_size (int): spacing between grid lines in pixels.

    Returns:
        PIL.Image.Image: the RGBA grid image.
    """
    grid = Image.new("RGBA", (width, height), (255, 255, 255, 0))
    pen = ImageDraw.Draw(grid)
    line_color = (200, 200, 200)

    vertical = [[(x, 0), (x, height)] for x in range(0, width, grid_size)]
    horizontal = [[(0, y), (width, y)] for y in range(0, height, grid_size)]
    for segment in vertical + horizontal:
        pen.line(segment, fill=line_color, width=1)

    return grid


def add_provided_traj(traj_list, traj_name):
    """Replace the session trajectory, in place, with a preset from ``PROVIDED_TRAJS``.

    Args:
        traj_list (list): per-session list of ``[x, y]`` points; cleared and refilled in place.
        traj_name (str): key into ``PROVIDED_TRAJS``.

    Returns:
        tuple: (preview image, comma-separated point string, ``gr.update(visible=True)``).

    Raises:
        KeyError: if ``traj_name`` is not a known preset (``traj_list`` is already cleared).
    """
    traj_list.clear()
    # Copy each point so the session state never shares list objects with the module-level
    # presets; an in-place edit of a point would otherwise corrupt the preset for later users.
    traj_list.extend([list(point) for point in PROVIDED_TRAJS[traj_name]])
    traj_str = [f"{traj}" for traj in traj_list]
    img = fn_vis_realtime_traj(traj_list)
    return img, ", ".join(traj_str), gr.update(visible=True)


def add_provided_prompt(prompt_name):
    """Return the example prompt text registered under ``prompt_name``.

    Raises:
        KeyError: if ``prompt_name`` is not a key of ``PROVIDED_PROMPTS``.
    """
    prompt_text = PROVIDED_PROMPTS[prompt_name]
    return prompt_text


def add_traj_point(
    traj_list,
    evt: gr.SelectData,
):
    """Append the clicked canvas position to the trajectory and refresh the preview.

    Args:
        traj_list (list): per-session list of ``[x, y]`` points, extended in place.
        evt (gr.SelectData): click event; the first two entries of ``evt.index`` are used as
            the clicked (x, y) position.

    Returns:
        tuple: (preview image, comma-separated point string).
    """
    # Build a fresh [int, int] list instead of rewriting ``evt.index`` in place. The old in-place
    # rewrite failed whenever the index arrived as an immutable tuple.
    x, y = evt.index[0], evt.index[1]
    traj_list.append([int(x), int(y)])
    img = fn_vis_realtime_traj(traj_list)
    traj_str = [f"{traj}" for traj in traj_list]
    return img, ", ".join(traj_str)


def fn_traj_droplast(traj_list):
    """Remove the most recent trajectory point (if any) and redraw the preview.

    Args:
        traj_list (list): per-session list of points, modified in place.

    Returns:
        tuple: (preview image, point string or placeholder text, ``gr.update(visible=True)``).
        When no points remain, the image is a blank white canvas.
    """
    if traj_list:
        traj_list.pop()

    if not traj_list:
        blank = np.ones((canvas_height, canvas_width, 3), dtype=np.uint8) * 255
        return blank, "Click to specify trajectory", gr.update(visible=True)

    img = fn_vis_realtime_traj(traj_list)
    listing = ", ".join(f"{traj}" for traj in traj_list)
    return img, listing, gr.update(visible=True)


def fn_traj_reset(traj_list):
    """Clear the trajectory in place and return a blank preview.

    Returns:
        tuple: (white ``(canvas_height, canvas_width, 3)`` uint8 image, placeholder text,
        ``gr.update(visible=True)``).
    """
    del traj_list[:]
    blank_canvas = np.full((canvas_height, canvas_width, 3), 255, dtype=np.uint8)
    return blank_canvas, "Click to specify trajectory", gr.update(visible=True)


def scale_traj_list_to_256(traj_list, canvas_width, canvas_height):
    """Map canvas pixel coordinates into the 256 x 256 range expected by ``process_traj``.

    Each coordinate is multiplied by ``256 / canvas_size`` and truncated with ``int``; no
    clamping is applied.

    Args:
        traj_list: iterable of ``(x, y)`` canvas points.
        canvas_width (int | float): canvas width in pixels.
        canvas_height (int | float): canvas height in pixels.

    Returns:
        list[list[int]]: the rescaled points.
    """
    sx, sy = 256 / canvas_width, 256 / canvas_height
    scaled = []
    for x, y in traj_list:
        scaled.append([int(x * sx), int(y * sy)])
    return scaled


###########################################

import math
from typing import List, Union

from diffusion_video import SATVideoDiffusionEngine
from einops import rearrange, repeat
from omegaconf import ListConfig
from PIL import Image, ImageDraw
from torchvision.io import write_video
from torchvision.utils import flow_to_image

from sat.arguments import set_random_seed
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint


# Diffusion engine instance; populated by cold_start() and used by model_run_v2().
model = None


def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
    """Assemble the conditional and unconditional conditioner inputs.

    Args:
        keys: conditioner input keys to populate.
        value_dict (dict): source values. For the ``"txt"`` key, ``"prompt"`` and
            ``"negative_prompt"`` are used; any other key is copied from ``value_dict`` by name.
        N: batch shape; each text is repeated ``prod(N)`` times and reshaped to ``N``.
        T (int, optional): if given, stored as ``batch["num_video_frames"]``.
        device: unused.

    Returns:
        tuple: ``(batch, batch_uc)``. ``batch_uc`` holds the negative prompt under ``"txt"`` and a
        ``torch.clone`` of every other tensor-valued entry of ``batch``.
    """

    def _tile(text):
        # Only evaluated when a "txt" key is requested.
        return np.repeat([text], repeats=math.prod(N)).reshape(N).tolist()

    batch, batch_uc = {}, {}
    for key in keys:
        if key != "txt":
            batch[key] = value_dict[key]
            continue
        batch["txt"] = _tile(value_dict["prompt"])
        batch_uc["txt"] = _tile(value_dict["negative_prompt"])

    if T is not None:
        batch["num_video_frames"] = T

    for key, value in batch.items():
        if key not in batch_uc and isinstance(value, torch.Tensor):
            batch_uc[key] = torch.clone(value)
    return batch, batch_uc


def get_unique_embedder_keys_from_conditioner(conditioner):
    """Return the distinct ``input_key`` values of ``conditioner.embedders``.

    Keys are de-duplicated in first-seen order via ``dict.fromkeys`` instead of a ``set``. The
    result, and therefore the key order of batches built from it, is then deterministic across
    runs rather than depending on string-hash randomization.
    """
    return list(dict.fromkeys(embedder.input_key for embedder in conditioner.embedders))


def draw_points(video, points):
    """
    Overlay a trajectory on every frame of a video.

    The full polyline through ``points`` is alpha-blended (alpha 0.7) onto every frame, and a
    filled dot marks ``points[t]`` on frame ``t``.

    Parameters:
        video (torch.tensor): Video tensor with shape [T, H, W, C]. C is assumed to be 3,
            because the overlay mask is repeated to 3 channels. The caller in this file passes
            a uint8 tensor.
        points (sequence): At least T ``(x, y)`` pixel positions, one per frame.

    Returns:
        torch.tensor: A new tensor with the same shape, dtype and device as ``video``.
    """
    T = video.shape[0]
    device = video.device
    dtype = video.dtype
    video = video.cpu().numpy().copy()

    # Rasterize the whole path once. The colour's last channel (1) is non-zero only on drawn
    # pixels, so it doubles as the overlay mask below.
    traj = np.zeros(video.shape[-3:], dtype=np.uint8)  # [H, W, C]
    for t in range(1, T):
        cv2.line(traj, tuple(points[t - 1]), tuple(points[t]), (255, 1, 1), 2)

    # The mask is identical for every frame, so compute it once rather than once per frame.
    mask = repeat(traj[..., -1] > 0, "h w -> h w c", c=3)
    alpha = 0.7
    # Blend all frames in one vectorized step (same per-element arithmetic as a per-frame loop).
    video[:, mask] = video[:, mask] * (1 - alpha) + traj[mask] * alpha
    for t in range(T):
        cv2.circle(video[t], tuple(points[t]), 3, (160, 230, 100), -1)

    return torch.from_numpy(video).to(device, dtype)


def _write_mp4(frames, fps):
    """Encode uint8 [T, H, W, C] frames to a new H.264 MP4 in ``tempfile_dir`` and return its path."""
    with tempfile.NamedTemporaryFile(dir=tempfile_dir, suffix=".mp4", delete=False) as f:
        out_path = f.name
    write_video(
        out_path,
        frames,
        fps=fps,
        video_codec="libx264",
        options={"crf": "23"},
    )
    return out_path


def save_video_as_grid_and_mp4(video_batch: torch.Tensor, fps: int = 8, args=None, key=None, traj_points=None):
    """Encode the first video of a batch to MP4, optionally also with a trajectory overlay.

    Args:
        video_batch (torch.Tensor): batch of videos; each item is [T, C, H, W] with values in
            [0, 1]. Only ``video_batch[0]`` is written.
        fps (int): frame rate of the written files.
        args, key: unused; kept for call compatibility.
        traj_points (sequence, optional): per-frame (x, y) pixel positions. When given, a second
            video with the trajectory drawn on it (via ``draw_points``) is written too.

    Returns:
        list[str]: ``[path]`` for the plain video, or ``[path, traj_path]`` when ``traj_points``
        is given.
    """
    vid = video_batch[0]
    x = rearrange(vid, "t c h w -> t h w c")
    x = x.mul(255).add(0.5).clamp(0, 255).to("cpu", torch.uint8)  # [T H W C]

    # Actually write the plain video. Previously ``path`` was only reserved as an empty temp
    # file and returned without any content.
    path = _write_mp4(x, fps)
    if traj_points is None:
        return [path]

    # traj video (draw_points works on a copy, so ``x`` is left untouched)
    traj_path = _write_mp4(draw_points(x, traj_points), fps)
    print("write video success.")
    return [path, traj_path]


def delete_old_files(folder_path="/tmp/Tora", hours=48, check_interval=3600 * 24):
    """
    Forever: every ``check_interval`` seconds, delete files in ``folder_path`` older than ``hours``.

    Age is measured with ``os.path.getctime``. On Unix this is the inode status-change time,
    not a true creation time. Sub-directories are left alone. Errors are printed and the loop
    keeps going; the function never returns, so it is presumably meant for a background thread.

    :param folder_path: The folder to sweep
    :param hours: Maximum age in hours before a file is deleted (default 48 hours)
    :param check_interval: Seconds to sleep between sweeps (default once per day)
    """

    def sweep():
        # One pass over the folder; per-file removal errors are reported and skipped.
        cutoff_time = time.time() - hours * 3600
        for filename in os.listdir(folder_path):
            file_path = os.path.join(folder_path, filename)
            if not os.path.isfile(file_path):
                continue
            creation_time = os.path.getctime(file_path)
            if creation_time >= cutoff_time:
                continue
            try:
                os.remove(file_path)
                print(f"Deleted file: {file_path}, creation time: {creation_time}")
            except Exception as e:
                print(f"Error deleting file: {file_path}. Error: {e}")

    while True:
        print("Checking temporary files...")
        try:
            sweep()
        except Exception as e:
            print(f"Error checking files: {e}")
        time.sleep(check_interval)


def cold_start(args):
    """Build the diffusion engine from ``args``, load its checkpoint, and store it in ``model``.

    If ``OMPI_COMM_WORLD_LOCAL_RANK`` is set (presumably by an Open MPI launcher), the
    ``OMPI_COMM_WORLD_*`` rank/size variables are copied into ``LOCAL_RANK``, ``WORLD_SIZE`` and
    ``RANK`` before the model is built.
    """
    global model

    env = os.environ
    if "OMPI_COMM_WORLD_LOCAL_RANK" in env:
        for target, source in (
            ("LOCAL_RANK", "OMPI_COMM_WORLD_LOCAL_RANK"),
            ("WORLD_SIZE", "OMPI_COMM_WORLD_SIZE"),
            ("RANK", "OMPI_COMM_WORLD_RANK"),
        ):
            env[target] = env[source]

    model_cls = SATVideoDiffusionEngine
    model = get_model(args, model_cls) if isinstance(model_cls, type) else model_cls

    load_checkpoint(model, args)
    model.eval()


def model_run_v2(prompt, seed, traj_list, n_samples=1):
    """Generate one trajectory-conditioned video and return it as a Gradio update.

    Steps, as implemented below:

    1. Rescale the canvas trajectory to the 256 range.
    2. Build a blurred optical-flow field with ``process_traj``.
    3. Render the flow as RGB with ``flow_to_image`` and encode it with the model's first-stage
       (VAE) encoder.
    4. Sample latents conditioned on the prompt and the encoded flow.
    5. Decode the latents chunk by chunk.
    6. Write MP4 files, including one with the trajectory drawn on top.

    Args:
        prompt (str): text prompt.
        seed: seed passed to ``set_random_seed``.
        traj_list (list): trajectory points in canvas pixel coordinates
            (``canvas_width`` x ``canvas_height``). Not modified.
        n_samples (int): unused; exactly one sample is generated.

    Returns:
        ``gr.update`` whose value is the path of the MP4 with the trajectory overlay
        (``file_path_list[1]``), displayed at 480 x 720.

    NOTE(review): ``model`` and its ``first_stage_model`` are moved between CPU and ``device``
    several times to limit GPU memory, and the "cuda" device string is hard-coded in places.
    """
    global model

    image_size = [480, 720]  # [height, width] of the generated video
    sampling_num_frames = 13  # Must be 13, 11 or 9
    latent_channels = 16
    sampling_fps = 8

    sample_func = model.sample
    # F: spatial factor between pixel and latent size (latent shape below is H // F x W // F).
    T, H, W, C, F = sampling_num_frames, image_size[0], image_size[1], latent_channels, 8
    num_samples = [1]  # batch shape handed to get_batch
    force_uc_zero_embeddings = ["txt"]
    device = model.device

    # Canvas-space copy is kept for drawing the overlay on the final video; the 256-range copy
    # drives the flow conditioning.
    global canvas_width, canvas_height
    traj_list_range_video = traj_list.copy()
    traj_list_range_256 = scale_traj_list_to_256(traj_list, canvas_width, canvas_height)

    with torch.no_grad():
        set_random_seed(seed)
        total_num_frames = (T - 1) * 4 + 1  # T is the latent length: (13 - 1) * 4 + 1 = 49 video frames

        # NOTE(review): `points` is not used below.
        video_flow, points = process_traj(traj_list_range_256, total_num_frames, image_size, device=device)
        video_flow = video_flow.unsqueeze_(0)  # add batch dim: [1, T, H, W, 2]

        # NOTE(review): always true — process_traj always returns a tensor.
        if video_flow is not None:
            model.to("cpu")  # move model to cpu, run vae on gpu only.
            tmp = rearrange(video_flow[0], "T H W C -> T C H W")
            video_flow = flow_to_image(tmp).unsqueeze_(0).to("cuda")  # [1 T C H W]

            del tmp
            # Scale RGB from [0, 255] to [-1, 1] and reorder to [B, C, T, H, W] in bfloat16.
            video_flow = (
                rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(torch.bfloat16)
            )
            torch.cuda.empty_cache()
            video_flow = video_flow.repeat(2, 1, 1, 1, 1).contiguous()  # for uncondition (batch doubled)
            model.first_stage_model.to(device)
            video_flow = model.encode_first_stage(video_flow, None)
            # Swap dims 1 and 2 (presumably latent [B, C, T, h, w] -> [B, T, C, h, w]).
            video_flow = video_flow.permute(0, 2, 1, 3, 4).contiguous()
            model.to(device)

        value_dict = {
            "prompt": prompt,
            "negative_prompt": "",
            "num_frames": torch.tensor(T).unsqueeze(0),
        }

        batch, batch_uc = get_batch(
            get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
        )

        c, uc = model.conditioner.get_unconditional_conditioning(
            batch,
            batch_uc=batch_uc,
            force_uc_zero_embeddings=force_uc_zero_embeddings,
        )

        # Trim every non-"crossattn" conditioning entry to the batch size and move it to the GPU.
        for k in c:
            if not k == "crossattn":
                c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))

        for index in range(1):  # num_samples (exactly one sample; n_samples is ignored)
            # reload model on GPU
            model.to(device)

            samples_z = sample_func(
                c,
                uc=uc,
                batch_size=1,
                shape=(T, C, H // F, W // F),
                video_flow=video_flow,
            )
            samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()

            # Unload the model from GPU to save GPU memory
            model.to("cpu")
            torch.cuda.empty_cache()
            first_stage_model = model.first_stage_model
            first_stage_model = first_stage_model.to(device)

            latent = 1.0 / model.scale_factor * samples_z

            # Decode latent serial to save GPU memory.
            # Latent-frame windows along dim 2: [0, 3) first, then [2i + 1, 2i + 3); with T = 13
            # that is [0,3), [3,5), [5,7), [7,9), [9,11), [11,13).
            recons = []
            loop_num = (T - 1) // 2
            for i in range(loop_num):
                if i == 0:
                    start_frame, end_frame = 0, 3
                else:
                    start_frame, end_frame = i * 2 + 1, i * 2 + 3
                # Only the final chunk passes clear_fake_cp_cache=True (presumably flushing the
                # decoder's cached context once decoding is complete).
                if i == loop_num - 1:
                    clear_fake_cp_cache = True
                else:
                    clear_fake_cp_cache = False
                with torch.no_grad():
                    recon = first_stage_model.decode(
                        latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
                    )

                recons.append(recon)

            recon = torch.cat(recons, dim=2).to(torch.float32)
            samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
            # Map decoder output from [-1, 1] to [0, 1].
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
            # [b, f, c, h, w]
            file_path_list = save_video_as_grid_and_mp4(
                samples,
                fps=sampling_fps,
                traj_points=process_points(traj_list_range_video),  # interpolate to 49 points
            )
            print(file_path_list)

        # Drop large tensors before returning to release GPU/CPU memory.
        del samples_z, samples_x, samples, video_flow, latent, recon, recons, c, uc, batch, batch_uc
        gc.collect()
        torch.cuda.empty_cache()

        return gr.update(value=file_path_list[1], height=image_size[0], width=image_size[1])


def main(args):
    """Build the Tora Gradio UI and launch it.

    Args:
        args: dict of keyword arguments forwarded verbatim to ``demo.launch``
            (e.g. ``server_name``, ``server_port``, ``auth``, ``share``, ``inbrowser``).

    Side effects:
        Sets the module globals ``canvas_width`` / ``canvas_height`` and rebinds the
        module global ``PROVIDED_TRAJS`` with every point scaled by
        ``canvas_width / 256`` (x) and ``canvas_height / 256`` (y). The scaling is
        applied to the current global value, so calling ``main`` more than once
        would rescale the example trajectories again.
    """
    global canvas_width, canvas_height, PROVIDED_TRAJS
    canvas_width, canvas_height, grid_size = 720, 480, 120
    grid_image = create_grid_image(canvas_width, canvas_height, grid_size)

    # Example trajectories are stored in a 256-unit coordinate range per axis
    # (presumably authored on a 256x256 grid); map them onto the drawing canvas.
    PROVIDED_TRAJS = {
        name: [[int(x * (canvas_width / 256)), int(y * (canvas_height / 256))] for x, y in points]
        for name, points in PROVIDED_TRAJS.items()
    }

    # Button labels double as lookup keys: each button passes its own value to its handler.
    traj_example_rows = (
        ("circle", "spiral", "coaster", "dance"),
        ("infinity", "pause", "shake", "wave"),
    )
    prompt_example_names = ("rubber duck", "dandelion", "golden retriever", "squirrel")

    demo = gr.Blocks()
    with demo:
        gr.Markdown("""
            <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
                Tora
            </div>
            <div style="text-align: center;font-size: 20px;">
                <a href="https://github.com/alibaba/Tora">Github</a> |
                <a href="https://ali-videoai.github.io/tora_video/">Project Page</a> |
                <a href="https://arxiv.org/abs/2407.21705">arXiv</a>
            </div>
            """)

        with gr.Column():
            with gr.Row():
                with gr.Column():
                    # Step 1: draw a trajectory on the canvas
                    gr.Markdown("---\n## Step 1/2: Draw A Trajectory", show_label=False, visible=True)
                    gr.Markdown(
                        "\n 1. **Click on the `Canvas` to create a trajectory.** Each click adds a new point to the trajectory. \
                        \n 2. Click on `Visualize Trajectory` to view the trajectory as a video; \
                        \n 3. Click on `Reset Trajectory` to clear the trajectory. (Currently, this demo does not support multi-trajectory control. To achieve multi-trajectory control, you can use the command line available on <a href='https://github.com/alibaba/Tora'>GitHub</a>.)",
                        show_label=False,
                        visible=True,
                    )

                    traj_args = gr.Textbox(value="", label="Points of Trajectory", visible=True)
                    traj_list = gr.State([])
                    with gr.Row():
                        traj_vis = gr.Button(value="Visualize Trajectory", visible=True)
                        traj_reset = gr.Button(value="Reset Trajectory", visible=True)
                        traj_droplast = gr.Button(value="Drop Last Point", visible=True)

                with gr.Column():
                    traj_input = gr.Image(
                        grid_image,
                        width=canvas_width // 2,
                        height=canvas_height // 2,
                        label="Canvas for Drawing",
                        visible=True,
                    )

                    vis_traj = gr.Video(
                        value=None,
                        label="Trajectory",
                        visible=True,
                        width=canvas_width // 2,
                        height=canvas_height // 2,
                    )

            # Step 2: add a prompt and generate videos
            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        "---\n## Step 2/2: Add prompt(Highly recommend using GPT-4o for refinement).",
                        show_label=False,
                        visible=True,
                    )
                    prompt = gr.Textbox(value="", label="Prompt", interactive=True, visible=True)
                    # Hidden: the generation callback receives it but users cannot change it.
                    n_samples = gr.Number(value=1, precision=0, interactive=True, label="n_samples", visible=False)
                    seed = gr.Number(value=1234, precision=0, interactive=True, label="Seed", visible=True)
                    start = gr.Button(value="Generate", visible=True)
                with gr.Column():
                    gen_video = gr.Video(value=None, label="Generate Video", visible=True)

            # Trajectory examples: one gr.Row per entry of traj_example_rows.
            traj_buttons = []
            with gr.Column():
                gr.Markdown("---\n## Trajectory Examples", show_label=False, visible=True)
                for row_names in traj_example_rows:
                    with gr.Row():
                        for traj_name in row_names:
                            traj_buttons.append(gr.Button(value=traj_name, visible=True))

            # Prompt examples
            prompt_buttons = []
            with gr.Column():
                gr.Markdown("---\n## Prompt Examples", show_label=False, visible=True)
                with gr.Row():
                    for prompt_name in prompt_example_names:
                        prompt_buttons.append(gr.Button(value=prompt_name, visible=True))

        # NOTE(review): traj_input is listed twice in the trajectory handlers' outputs, so
        # those handlers presumably return three values (image, text, image) — confirm
        # against add_provided_traj / fn_traj_droplast / fn_traj_reset before changing.
        for traj_button in traj_buttons:
            traj_button.click(
                fn=add_provided_traj,
                inputs=[traj_list, traj_button],
                outputs=[traj_input, traj_args, traj_input],
            )

        for prompt_button in prompt_buttons:
            prompt_button.click(fn=add_provided_prompt, inputs=prompt_button, outputs=prompt)

        traj_vis.click(
            fn=fn_vis_traj,
            inputs=traj_list,
            outputs=[vis_traj],
        )
        # Clicking on the canvas appends a point to the trajectory.
        traj_input.select(fn=add_traj_point, inputs=traj_list, outputs=[traj_input, traj_args])
        traj_droplast.click(fn=fn_traj_droplast, inputs=traj_list, outputs=[traj_input, traj_args, traj_input])
        traj_reset.click(fn=fn_traj_reset, inputs=traj_list, outputs=[traj_input, traj_args, traj_input])

        start.click(fn=model_run_v2, inputs=[prompt, seed, traj_list, n_samples], outputs=gen_video)

        gr.Markdown(article)

    # Bound the number of queued requests waiting for the (single) model.
    demo.queue(max_size=32).launch(**args)


if __name__ == "__main__":
    # Usage: python app.py --load {model_path} [--base "model.yaml inference.yaml"]

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--listen",
        type=str,
        default="0.0.0.0" if "SPACE_ID" in os.environ else "127.0.0.1",
        help="IP to listen on for connections to Gradio",
    )
    parser.add_argument("--username", type=str, default="", help="Username for authentication")
    parser.add_argument("--password", type=str, default="", help="Password for authentication")
    parser.add_argument(
        "--server_port",
        type=int,
        default=0,
        help="Port to run the server listener on",
    )
    parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
    parser.add_argument("--share", action="store_true", help="Share the gradio UI")

    parser.add_argument(
        "--base",
        type=str,
        default="configs/tora/model/cogvideox_5b_tora.yaml configs/tora/inference_sparse.yaml",
        help="Whitespace-separated config files forwarded to Tora's argument parser as --base",
    )
    parser.add_argument(
        "--load",
        type=str,
        default="ckpts/tora/t2v/",
        help="Checkpoint path forwarded to Tora's argument parser as --load",
    )

    args = parser.parse_args()

    # Keyword arguments for demo.launch(); optional settings are only forwarded when set.
    launch_kwargs = {"server_name": args.listen}

    if args.username and args.password:
        launch_kwargs["auth"] = (args.username, args.password)
    elif args.username or args.password:
        # Auth needs both values; warn instead of silently serving without it.
        print("WARNING: both --username and --password are required for authentication; launching without auth.")
    if args.server_port:
        launch_kwargs["server_port"] = args.server_port
    if args.inbrowser:
        launch_kwargs["inbrowser"] = args.inbrowser
    if args.share:
        launch_kwargs["share"] = args.share

    from arguments import get_args

    # Forward --base (split on whitespace, so several config files can be given) and
    # --load to Tora's own argument parser. The default --base yields the same two configs.
    tora_args_list = ["--base", *args.base.split(), "--load", args.load]
    tora_args = get_args(tora_args_list)
    tora_args = argparse.Namespace(**vars(tora_args))

    # Single-process demo: disable context/model parallelism and activation checkpointing,
    # and en/decode one sample at a time.
    tora_args.model_config.first_stage_config.params.cp_size = 1
    tora_args.model_config.network_config.params.transformer_args.model_parallel_size = 1
    tora_args.model_config.network_config.params.transformer_args.checkpoint_activations = False
    tora_args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
    tora_args.model_config.en_and_decode_n_samples_a_time = 1

    cold_start(args=tora_args)
    print("******************** model loaded ********************")

    # Background cleanup of the shared temp directory (daemon thread, exits with the app).
    # NOTE(review): meaning of (48, 3600 * 24) depends on delete_old_files's signature — confirm.
    threading.Thread(target=delete_old_files, args=(tempfile_dir, 48, 3600 * 24), daemon=True).start()

    main(launch_kwargs)
