sat/app.py
# borrowed from https://github.com/TencentARC/MotionCtrl/blob/main/app.py
import argparse
import gc
import os
import tempfile
import threading
import time
import cv2
import gradio as gr
import imageio
import numpy as np
import torch
tempfile_dir = "/tmp/Tora"
os.makedirs(tempfile_dir, exist_ok=True)
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
SPACE_ID = os.environ.get("SPACE_ID", "")
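# Download the Tora checkpoints from ModelScope into ./ckpts before building the UI.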
os.system("modelscope download --model=xiaoche/Tora --local_dir ./ckpts")
#### Description ####
title = r"""<h1 align="center">Tora: Trajectory-oriented Diffusion Transformer for Video Generation</h1>"""
description = r""""""
article = r"""
---
📝 **Citation**
<br>
```bibtex
@misc{zhang2024toratrajectoryorienteddiffusiontransformer,
title={Tora: Trajectory-oriented Diffusion Transformer for Video Generation},
author={Zhenghao Zhang and Junchao Liao and Menghao Li and Zuozhuo Dai and Bingxue Qiu and Siyu Zhu and Long Qin and Weizhi Wang},
year={2024},
eprint={2407.21705},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2407.21705},
}
```
"""
css = """
.gradio-container {width: 85% !important}
.gr-monochrome-group {border-radius: 5px !important; border: revert-layer !important; border-width: 2px !important; color: black !important;}
span.svelte-s1r2yt {font-size: 17px !important; font-weight: bold !important; color: #d30f2f !important;}
button {border-radius: 8px !important;}
.add_button {background-color: #4CAF50 !important;}
.remove_button {background-color: #f44336 !important;}
.clear_button {background-color: gray !important;}
.mask_button_group {gap: 10px !important;}
.video {height: 300px !important;}
.image {height: 300px !important;}
.video .wrap.svelte-lcpz3o {display: flex !important; align-items: center !important; justify-content: center !important;}
.video .wrap.svelte-lcpz3o > :first-child {height: 100% !important;}
.margin_center {width: 50% !important; margin: auto !important;}
.jc_center {justify-content: center !important;}
"""
traj_list = []
traj_list_range_256 = []
canvas_width, canvas_height = 256, 256
# Note that the coordinates passed to the model must not exceed 256.
# xy range 256
PROVIDED_TRAJS = {
"circle": [
[120, 194],
[144, 193],
[155, 189],
[158, 170],
[160, 153],
[159, 123],
[152, 113],
[136, 100],
[124, 100],
[108, 100],
[101, 106],
[90, 110],
[84, 129],
[79, 146],
[78, 165],
[83, 182],
[87, 189],
[94, 192],
[100, 194],
[106, 194],
[112, 194],
[118, 195],
],
"spiral": [
[100, 127],
[105, 117],
[122, 117],
[132, 129],
[133, 158],
[125, 181],
[108, 189],
[92, 185],
[84, 179],
[79, 163],
[75, 142],
[73, 118],
[75, 82],
[91, 63],
[115, 52],
[139, 46],
[154, 55],
[167, 93],
[175, 112],
[177, 137],
[177, 158],
[177, 171],
[175, 188],
[173, 204],
],
"coaster": [
[40, 208],
[40, 148],
[40, 100],
[52, 58],
[60, 57],
[74, 68],
[78, 90],
[84, 123],
[88, 148],
[96, 168],
[100, 181],
[102, 188],
[105, 192],
[113, 118],
[119, 80],
[128, 68],
[145, 109],
[149, 155],
[157, 175],
[161, 184],
[164, 184],
[172, 166],
[183, 107],
[189, 84],
[198, 76],
],
"dance": [
[81, 112],
[86, 112],
[92, 112],
[100, 113],
[102, 114],
[97, 115],
[92, 114],
[86, 112],
[81, 112],
[80, 112],
[84, 113],
[89, 114],
[95, 114],
[101, 114],
[102, 114],
[103, 124],
[105, 137],
[109, 156],
[114, 172],
[119, 180],
[124, 184],
[131, 181],
[140, 168],
[146, 152],
[150, 128],
[151, 117],
[152, 116],
[156, 116],
[163, 115],
[169, 116],
[175, 116],
[173, 116],
[167, 116],
[162, 114],
[157, 114],
[152, 115],
[156, 115],
[163, 115],
[168, 115],
[174, 116],
[175, 116],
[168, 116],
[162, 116],
[152, 114],
[149, 134],
[145, 156],
[139, 168],
[130, 183],
[118, 180],
[112, 170],
[107, 151],
[102, 128],
[103, 117],
[96, 113],
[88, 113],
[83, 112],
[80, 112],
],
"infinity": [
[60, 141],
[71, 127],
[92, 120],
[112, 123],
[130, 145],
[145, 163],
[167, 178],
[189, 187],
[206, 176],
[213, 147],
[208, 124],
[190, 112],
[176, 111],
[158, 124],
[145, 147],
[125, 172],
[104, 189],
[72, 189],
[59, 184],
[55, 153],
[57, 140],
[75, 119],
[112, 118],
[129, 142],
[149, 163],
[168, 180],
[194, 186],
[206, 175],
[211, 159],
[212, 149],
[212, 134],
[206, 122],
[180, 112],
[163, 116],
[149, 138],
[128, 170],
[108, 184],
[86, 190],
[63, 181],
[57, 152],
[57, 139],
],
"pause": [
[98, 186],
[100, 188],
[98, 186],
[100, 188],
[101, 187],
[104, 187],
[111, 184],
[116, 176],
[125, 162],
[132, 140],
[136, 119],
[137, 104],
[138, 96],
[139, 94],
[140, 94],
[140, 96],
[138, 98],
[138, 96],
[136, 94],
[137, 92],
[140, 92],
[144, 92],
[149, 92],
[152, 92],
[151, 92],
[147, 92],
[142, 92],
[140, 92],
[139, 95],
[139, 105],
[141, 122],
[142, 143],
[140, 167],
[136, 184],
[135, 188],
[132, 195],
[132, 192],
[131, 192],
[131, 192],
[130, 192],
[130, 195],
],
"shake": [
[103, 89],
[104, 89],
[106, 89],
[107, 89],
[108, 89],
[109, 89],
[110, 89],
[111, 89],
[112, 89],
[113, 89],
[114, 89],
[115, 89],
[116, 89],
[117, 89],
[118, 89],
[119, 89],
[120, 89],
[122, 89],
[123, 89],
[124, 89],
[125, 89],
[126, 89],
[127, 88],
[128, 88],
[129, 88],
[130, 88],
[131, 88],
[133, 87],
[136, 86],
[137, 86],
[138, 86],
[139, 86],
[140, 86],
[141, 86],
[142, 86],
[143, 86],
[144, 86],
[145, 86],
[146, 87],
[147, 87],
[148, 87],
[149, 87],
[148, 87],
[146, 87],
[145, 88],
[144, 88],
[142, 89],
[141, 89],
[140, 90],
[140, 91],
[138, 91],
[137, 92],
[136, 92],
[136, 93],
[135, 93],
[134, 93],
[133, 93],
[132, 93],
[131, 93],
[130, 93],
[129, 93],
[128, 93],
[127, 92],
[125, 92],
[124, 92],
[123, 92],
[122, 92],
[121, 92],
[120, 92],
[119, 92],
[118, 92],
[117, 92],
[116, 92],
[115, 92],
[113, 92],
[112, 92],
[111, 92],
[110, 92],
[109, 92],
[108, 92],
[108, 91],
[108, 90],
[109, 90],
[110, 90],
[111, 89],
[112, 89],
[113, 89],
[114, 89],
[115, 89],
[115, 88],
[116, 88],
[117, 88],
[118, 88],
[118, 87],
[119, 87],
[120, 87],
[121, 87],
[122, 86],
[123, 86],
[124, 86],
[125, 86],
[126, 85],
[127, 85],
[128, 85],
[129, 85],
[130, 85],
[131, 85],
[132, 85],
[133, 85],
[134, 85],
[135, 85],
[136, 85],
[137, 85],
[138, 85],
[139, 85],
[140, 85],
[141, 85],
[142, 85],
[143, 85],
[143, 84],
[144, 84],
[145, 84],
[146, 84],
[147, 84],
[148, 84],
[149, 84],
[148, 84],
[147, 84],
[145, 84],
[144, 84],
[143, 84],
[142, 84],
[141, 84],
[140, 85],
[139, 85],
[138, 85],
[137, 86],
[136, 86],
[136, 87],
[135, 87],
[134, 87],
[133, 87],
[132, 88],
[131, 88],
[130, 88],
[129, 88],
[129, 89],
[128, 89],
[127, 89],
[126, 89],
[125, 89],
[124, 90],
[123, 90],
[122, 90],
[121, 90],
[120, 91],
[119, 91],
[118, 91],
[117, 91],
[116, 91],
[115, 91],
[114, 91],
[113, 91],
[112, 91],
[111, 91],
[110, 91],
[109, 91],
[109, 90],
[108, 90],
[110, 90],
[111, 90],
[113, 90],
[114, 90],
[115, 90],
[116, 90],
[118, 90],
[120, 90],
[121, 90],
[122, 90],
[123, 90],
[124, 90],
[126, 90],
[127, 90],
[128, 90],
[129, 90],
[130, 90],
[131, 90],
[132, 90],
[133, 90],
[134, 90],
[135, 90],
[136, 90],
[137, 90],
[138, 90],
[139, 90],
[140, 90],
[141, 89],
[142, 89],
[143, 89],
[144, 89],
[145, 89],
[146, 89],
[147, 89],
[147, 89],
[147, 89],
],
"wave": [
[16, 152],
[23, 138],
[39, 122],
[54, 115],
[75, 118],
[88, 130],
[93, 150],
[89, 176],
[75, 184],
[63, 177],
[65, 152],
[77, 135],
[98, 121],
[116, 120],
[135, 127],
[148, 136],
[156, 145],
[160, 165],
[158, 176],
[138, 187],
[133, 185],
[129, 148],
[140, 133],
[156, 120],
[177, 118],
[197, 118],
[214, 119],
[225, 118],
],
}
PROVIDED_PROMPTS = {
"dandelion": "A dandelion puff sways gently in the wind, its seeds ready to take flight and spread into the world. The animation style highlights the delicate fibers of the puff, with soft, glowing light surrounding it. The background showcases a lush, green field, hinting at the beauty of nature. As the wind blows, the seeds dance and float away, creating an enchanting visual narrative. The gentle sounds of nature, alongside soft whispers of the breeze, enrich the overall ambiance. This serene scene invites viewers to embrace the moment of letting go, celebrating the cycle of life and new beginnings.",
"golden retriever": "A golden retriever, sporting sleek black sunglasses, with its lengthy fur flowing in the breeze, sprints playfully across a rooftop terrace, recently refreshed by a light rain. The scene unfolds from a distance, the dog's energetic bounds growing larger as it approaches the camera, its tail wagging with unrestrained joy, while droplets of water glisten on the concrete behind it. The overcast sky provides a dramatic backdrop, emphasizing the vibrant golden coat of the canine as it dashes towards the viewer.",
"rubber duck": "A cheerful rubber duck floats serenely in a bathtub filled with bubbles, the soft foam creating an inviting atmosphere. The bathroom setting is warm with bright tiles reflecting soft light. The camera captures playful angles, zeroing in on the duck's bright yellow color and big eyes. Sounds of water gently splashing and laughter fill the background, enhancing the joyous ambiance. This moment invites viewers to embrace nostalgia and childhood fun, evoking a sense of playfulness and relaxation.",
"squirrel": "A squirrel gathering nuts.",
}
#############################
def pdf2(sigma_matrix, grid):
"""Calculate PDF of the bivariate Gaussian distribution.
Args:
sigma_matrix (ndarray): with the shape (2, 2)
grid (ndarray): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size.
Returns:
kernel (ndarray): un-normalized kernel.
"""
inverse_sigma = np.linalg.inv(sigma_matrix)
kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
return kernel
def mesh_grid(kernel_size):
"""Generate the mesh grid, centering at zero.
Args:
kernel_size (int):
Returns:
xy (ndarray): with the shape (kernel_size, kernel_size, 2)
xx (ndarray): with the shape (kernel_size, kernel_size)
yy (ndarray): with the shape (kernel_size, kernel_size)
"""
ax = np.arange(-kernel_size // 2 + 1.0, kernel_size // 2 + 1.0)
xx, yy = np.meshgrid(ax, ax)
xy = np.hstack(
(
xx.reshape((kernel_size * kernel_size, 1)),
yy.reshape(kernel_size * kernel_size, 1),
)
).reshape(kernel_size, kernel_size, 2)
return xy, xx, yy
def sigma_matrix2(sig_x, sig_y, theta):
"""Calculate the rotated sigma matrix (two dimensional matrix).
Args:
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
Returns:
ndarray: Rotated sigma matrix.
"""
d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
"""Generate a bivariate isotropic or anisotropic Gaussian kernel.
In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.
Args:
kernel_size (int):
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
isotropic (bool):
Returns:
kernel (ndarray): normalized kernel.
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
if isotropic:
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
else:
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
kernel = pdf2(sigma_matrix, grid)
kernel = kernel / np.sum(kernel)
return kernel
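# Precompute a 99x99 isotropic Gaussian kernel (sigma=10), rescaled so its center equals 1;
# it is used to spread each sparse trajectory displacement into a soft local motion field.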
size = 99
sigma = 10
blur_kernel = bivariate_Gaussian(size, sigma, sigma, 0, grid=None, isotropic=True)
blur_kernel = blur_kernel / blur_kernel[size // 2, size // 2]
#############################
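# Write the frame-to-frame displacement between consecutive trajectory points into a sparse
# optical-flow map at the earlier point's (x, y) location.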
def get_flow(points, optical_flow, video_len):
for i in range(video_len - 1):
p = points[i]
p1 = points[i + 1]
optical_flow[i + 1, p[1], p[0], 0] = p1[0] - p[0]
optical_flow[i + 1, p[1], p[0], 1] = p1[1] - p[1]
return optical_flow
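# Resample a clicked trajectory to the target number of frames: stride-downsample when there are
# too many points, linearly interpolate between consecutive clicks when there are too few.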
def process_points(points, frames=49):
default_points = [[128, 128]] * frames
if len(points) < 2:
return default_points
elif len(points) >= frames:
skip = len(points) // frames
return points[::skip][: frames - 1] + points[-1:]
else:
insert_num = frames - len(points)
insert_num_dict = {}
interval = len(points) - 1
n = insert_num // interval
m = insert_num % interval
for i in range(interval):
insert_num_dict[i] = n
for i in range(m):
insert_num_dict[i] += 1
res = []
for i in range(interval):
insert_points = []
x0, y0 = points[i]
x1, y1 = points[i + 1]
delta_x = x1 - x0
delta_y = y1 - y0
for j in range(insert_num_dict[i]):
x = x0 + (j + 1) / (insert_num_dict[i] + 1) * delta_x
y = y0 + (j + 1) / (insert_num_dict[i] + 1) * delta_y
insert_points.append([int(x), int(y)])
res += points[i : i + 1] + insert_points
res += points[-1:]
return res
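# Parse trajectory points given either as "x,y" strings or [x, y] pairs, optionally reversing
# the path and striding down to at most `video_len` points.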
def read_points_from_list(traj_list, video_len=16, reverse=False):
points = []
for point in traj_list:
if isinstance(point, str):
x, y = point.strip().split(",")
else:
x, y = point[0], point[1]
points.append((int(x), int(y)))
if reverse:
points = points[::-1]
if len(points) > video_len:
skip = len(points) // video_len
points = points[::skip]
points = points[:video_len]
return points
def read_points_from_file(file, video_len=16, reverse=False):
with open(file, "r") as f:
lines = f.readlines()
points = []
for line in lines:
x, y = line.strip().split(",")
points.append((int(x), int(y)))
if reverse:
points = points[::-1]
if len(points) > video_len:
skip = len(points) // video_len
points = points[::skip]
points = points[:video_len]
return points
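# Convert one or more trajectories (given in 256x256 coordinates) into a dense optical-flow tensor
# of shape [num_frames, H, W, 2], then smooth each frame with the Gaussian kernel defined above.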
def process_traj(trajs_list, num_frames, video_size, device="cpu"):
if trajs_list and trajs_list[0] and (not isinstance(trajs_list[0][0], (list, tuple))):
tmp = trajs_list
trajs_list = [tmp]
optical_flow = np.zeros((num_frames, video_size[0], video_size[1], 2), dtype=np.float32)
processed_points = []
for traj_list in trajs_list:
points = read_points_from_list(traj_list, video_len=num_frames)
xy_range = 256
h, w = video_size
points = process_points(points, num_frames)
points = [[int(w * x / xy_range), int(h * y / xy_range)] for x, y in points]
optical_flow = get_flow(points, optical_flow, video_len=num_frames)
processed_points.append(points)
print(f"received {len(trajs_list)} trajectorie(s)")
for i in range(1, num_frames):
optical_flow[i] = cv2.filter2D(optical_flow[i], -1, blur_kernel)
optical_flow = torch.tensor(optical_flow).to(device)
return optical_flow, processed_points
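# Draw the current trajectory as a static polyline on a white canvas, used for live feedback while the user clicks points.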
def fn_vis_realtime_traj(traj_list):
points = process_points(traj_list)
img = np.ones((canvas_height, canvas_width, 3), dtype=np.uint8) * 255
for i in range(len(points) - 1):
p = points[i]
p1 = points[i + 1]
cv2.line(img, p, p1, (255, 0, 0), 2)
return img
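# Render the trajectory as a short mp4: the full polyline is drawn in every frame while a green
# marker advances by one interpolated point per frame.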
def fn_vis_traj(traj_list):
# global traj_list
points = process_points(traj_list)
imgs = []
for idx in range(len(points)):
bg_img = np.ones((canvas_height, canvas_width, 3), dtype=np.uint8) * 255
for i in range(len(points) - 1):
p = points[i]
p1 = points[i + 1]
cv2.line(bg_img, p, p1, (255, 0, 0), 2)
if i == idx:
cv2.circle(bg_img, p, 2, (0, 255, 0), 20)
if idx == (len(points) - 1):
cv2.circle(bg_img, points[-1], 2, (0, 255, 0), 20)
imgs.append(bg_img.astype(np.uint8))
fps = 8
with tempfile.NamedTemporaryFile(dir=tempfile_dir, suffix=".mp4", delete=False) as f:
path = f.name
writer = imageio.get_writer(path, format="mp4", mode="I", fps=fps)
for img in imgs:
writer.append_data(img)
writer.close()
return path
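# Build a transparent canvas with light grid lines for the drawing area. PIL's Image/ImageDraw are
# imported further below; this function is only called from main(), after those imports have run.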
def create_grid_image(width, height, grid_size):
img = Image.new("RGBA", (width, height), (255, 255, 255, 0))
draw = ImageDraw.Draw(img)
for x in range(0, width, grid_size):
draw.line([(x, 0), (x, height)], fill=(200, 200, 200), width=1)
for y in range(0, height, grid_size):
draw.line([(0, y), (width, y)], fill=(200, 200, 200), width=1)
return img
def add_provided_traj(traj_list, traj_name):
traj_list.clear()
traj_list += PROVIDED_TRAJS[traj_name]
traj_str = [f"{traj}" for traj in traj_list]
img = fn_vis_realtime_traj(traj_list)
return img, ", ".join(traj_str), gr.update(visible=True)
def add_provided_prompt(prompt_name):
return PROVIDED_PROMPTS[prompt_name]
def add_traj_point(
traj_list,
evt: gr.SelectData,
):
# global traj_list
traj_list.append(evt.index)
traj_list[-1][0], traj_list[-1][1] = int(traj_list[-1][0]), int(traj_list[-1][1])
img = fn_vis_realtime_traj(traj_list)
traj_str = [f"{traj}" for traj in traj_list]
return img, ", ".join(traj_str)
def fn_traj_droplast(traj_list):
# global traj_list
if traj_list:
traj_list.pop()
if traj_list:
img = fn_vis_realtime_traj(traj_list)
traj_str = [f"{traj}" for traj in traj_list]
return img, ", ".join(traj_str), gr.update(visible=True)
else:
return (
np.ones((canvas_height, canvas_width, 3), dtype=np.uint8) * 255,
"Click to specify trajectory",
gr.update(visible=True),
)
def fn_traj_reset(traj_list):
# global traj_list
traj_list.clear()
# traj_list = []
return (
np.ones((canvas_height, canvas_width, 3), dtype=np.uint8) * 255,
"Click to specify trajectory",
gr.update(visible=True),
)
def scale_traj_list_to_256(traj_list, canvas_width, canvas_height):
scale_x = 256 / canvas_width
scale_y = 256 / canvas_height
scaled_traj_list = [[int(x * scale_x), int(y * scale_y)] for x, y in traj_list]
return scaled_traj_list
###########################################
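# Model-side imports and inference helpers (SAT / CogVideoX-based Tora engine).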
import math
from typing import List, Union
from diffusion_video import SATVideoDiffusionEngine
from einops import rearrange, repeat
from omegaconf import ListConfig
from PIL import Image, ImageDraw
from torchvision.io import write_video
from torchvision.utils import flow_to_image
from sat.arguments import set_random_seed
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint
model = None
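# Assemble the conditional and unconditional batches expected by the conditioner: the prompt and
# negative prompt are broadcast to the batch shape; all other keys are copied through.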
def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
batch = {}
batch_uc = {}
for key in keys:
if key == "txt":
batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
else:
batch[key] = value_dict[key]
if T is not None:
batch["num_video_frames"] = T
for key in batch.keys():
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
batch_uc[key] = torch.clone(batch[key])
return batch, batch_uc
def get_unique_embedder_keys_from_conditioner(conditioner):
return list({x.input_key for x in conditioner.embedders})
def draw_points(video, points):
"""
Draw points onto video frames.
Parameters:
video (torch.tensor): Video tensor with shape [T, H, W, C], where T is the number of frames,
H is the height, W is the width, and C is the number of channels.
points (list): Positions of the points to draw, a list of length T where
each element is an [x, y] coordinate pair.
Returns:
torch.tensor: The video tensor after drawing points, maintaining the same shape [T, H, W, C].
"""
T = video.shape[0]
N = len(points)
device = video.device
dtype = video.dtype
video = video.cpu().numpy().copy()
traj = np.zeros(video.shape[-3:], dtype=np.uint8) # [H, W, C]
for t in range(1, T):
cv2.line(traj, tuple(points[t - 1]), tuple(points[t]), (255, 1, 1), 2)
for t in range(T):
mask = traj[..., -1] > 0
mask = repeat(mask, "h w -> h w c", c=3)
alpha = 0.7
video[t][mask] = video[t][mask] * (1 - alpha) + traj[mask] * alpha
cv2.circle(video[t], tuple(points[t]), 3, (160, 230, 100), -1)
video = torch.from_numpy(video).to(device, dtype)
return video
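# Save the sampled video to a temp mp4. Note that the first temp file (`path`) is created but never
# written to in this code path; the Gradio callback always passes traj_points and uses the
# trajectory-overlaid video at index 1.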
def save_video_as_grid_and_mp4(video_batch: torch.Tensor, fps: int = 8, args=None, key=None, traj_points=None):
with tempfile.NamedTemporaryFile(dir=tempfile_dir, suffix=".mp4", delete=False) as f:
path = f.name
vid = video_batch[0]
x = rearrange(vid, "t c h w -> t h w c")
x = x.mul(255).add(0.5).clamp(0, 255).to("cpu", torch.uint8) # [T H W C]
if traj_points is not None:
# traj video
x = draw_points(x, traj_points)
with tempfile.NamedTemporaryFile(dir=tempfile_dir, suffix=".mp4", delete=False) as f:
traj_path = f.name
write_video(
traj_path,
x,
fps=fps,
video_codec="libx264",
options={"crf": "23"},
)
print("write video success.")
return [path, traj_path]
return [path]
def delete_old_files(folder_path="/tmp/Tora", hours=48, check_interval=3600 * 24):
"""
Periodically checks and deletes files in the specified folder that were created more than the specified number of hours ago.
:param folder_path: The path of the folder to check
:param hours: The number of hours after which files will be deleted (default 48 hours)
:param check_interval: The interval (in seconds) at which to check for files (default once per day)
"""
while True:
print("Checking temporary files...")
try:
# Get the current time in seconds since the epoch
now = time.time()
# Calculate the cutoff time in seconds
cutoff_time = now - hours * 3600
# Iterate over all files in the folder
for filename in os.listdir(folder_path):
file_path = os.path.join(folder_path, filename)
# Ensure it is a file and not a directory
if os.path.isfile(file_path):
# Get the creation time of the file
creation_time = os.path.getctime(file_path)
# Check if the file is older than the specified time
if creation_time < cutoff_time:
try:
os.remove(file_path)
print(f"Deleted file: {file_path}, creation time: {creation_time}")
except Exception as e:
print(f"Error deleting file: {file_path}. Error: {e}")
except Exception as e:
print(f"Error checking files: {e}")
# Sleep for the specified interval before checking again
time.sleep(check_interval)
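# Build and load the SAT video diffusion model once at startup; when launched via mpirun, the
# OMPI_* variables are mapped to the torch.distributed-style LOCAL_RANK/WORLD_SIZE/RANK variables.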
def cold_start(args):
global model
if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
model_cls = SATVideoDiffusionEngine
if isinstance(model_cls, type):
model = get_model(args, model_cls)
else:
model = model_cls
load_checkpoint(model, args)
model.eval()
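# End-to-end text + trajectory -> video: build a flow field from the clicked trajectory, encode it
# with the VAE, sample latents with the diffusion model, decode in chunks, and save the result as mp4.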
def model_run_v2(prompt, seed, traj_list, n_samples=1):
global model
image_size = [480, 720]
sampling_num_frames = 13 # Must be 13, 11 or 9
latent_channels = 16
sampling_fps = 8
sample_func = model.sample
T, H, W, C, F = sampling_num_frames, image_size[0], image_size[1], latent_channels, 8
num_samples = [1]
force_uc_zero_embeddings = ["txt"]
device = model.device
# global traj_list
global canvas_width, canvas_height
traj_list_range_video = traj_list.copy()
traj_list_range_256 = scale_traj_list_to_256(traj_list, canvas_width, canvas_height)
with torch.no_grad():
set_random_seed(seed)
total_num_frames = (T - 1) * 4 + 1  # T is the latent frame count; (13 - 1) * 4 + 1 = 49 pixel frames
video_flow, points = process_traj(traj_list_range_256, total_num_frames, image_size, device=device)
video_flow = video_flow.unsqueeze_(0)
if video_flow is not None:
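# Visualize the 2-channel flow as an RGB image (flow_to_image), normalize to [-1, 1], duplicate it
# along the batch dim for the unconditional branch, then encode it with the first-stage VAE.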
model.to("cpu") # move model to cpu, run vae on gpu only.
tmp = rearrange(video_flow[0], "T H W C -> T C H W")
video_flow = flow_to_image(tmp).unsqueeze_(0).to("cuda") # [1 T C H W]
del tmp
video_flow = (
rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(torch.bfloat16)
)
torch.cuda.empty_cache()
video_flow = video_flow.repeat(2, 1, 1, 1, 1).contiguous() # for uncondition
model.first_stage_model.to(device)
video_flow = model.encode_first_stage(video_flow, None)
video_flow = video_flow.permute(0, 2, 1, 3, 4).contiguous()
model.to(device)
value_dict = {
"prompt": prompt,
"negative_prompt": "",
"num_frames": torch.tensor(T).unsqueeze(0),
}
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
)
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
for k in c:
if not k == "crossattn":
c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
for index in range(1): # num_samples
# reload model on GPU
model.to(device)
samples_z = sample_func(
c,
uc=uc,
batch_size=1,
shape=(T, C, H // F, W // F),
video_flow=video_flow,
)
samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
# Unload the model from GPU to save GPU memory
model.to("cpu")
torch.cuda.empty_cache()
first_stage_model = model.first_stage_model
first_stage_model = first_stage_model.to(device)
latent = 1.0 / model.scale_factor * samples_z
# Decode the latents serially to save GPU memory
recons = []
loop_num = (T - 1) // 2
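# Decode in small temporal chunks (3 latent frames first, then 2 at a time) so the VAE never holds
# the full video in GPU memory; the fake context-parallel cache is cleared only on the last chunk.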
for i in range(loop_num):
if i == 0:
start_frame, end_frame = 0, 3
else:
start_frame, end_frame = i * 2 + 1, i * 2 + 3
if i == loop_num - 1:
clear_fake_cp_cache = True
else:
clear_fake_cp_cache = False
with torch.no_grad():
recon = first_stage_model.decode(
latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
)
recons.append(recon)
recon = torch.cat(recons, dim=2).to(torch.float32)
samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
# [b, f, c, h, w]
file_path_list = save_video_as_grid_and_mp4(
samples,
fps=sampling_fps,
traj_points=process_points(traj_list_range_video), # interpolate to 49 points
)
print(file_path_list)
del samples_z, samples_x, samples, video_flow, latent, recon, recons, c, uc, batch, batch_uc
gc.collect()
torch.cuda.empty_cache()
return gr.update(value=file_path_list[1], height=image_size[0], width=image_size[1])
def main(args):
global canvas_width, canvas_height
canvas_width, canvas_height, grid_size = 720, 480, 120
grid_image = create_grid_image(canvas_width, canvas_height, grid_size)
global PROVIDED_TRAJS
# scale provided trajs
PROVIDED_TRAJS = {
name: [[int(x * (canvas_width / 256)), int(y * (canvas_height / 256))] for x, y in points]
for name, points in PROVIDED_TRAJS.items()
}
demo = gr.Blocks()
with demo:
# gr.Markdown(title)
# gr.Markdown(description)
gr.Markdown("""
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
Tora
</div>
<div style="text-align: center;font-size: 20px;">
<a href="https://github.com/alibaba/Tora">Github</a> |
<a href="https://ali-videoai.github.io/tora_video/">Project Page</a> |
<a href="https://arxiv.org/abs/2407.21705">arXiv</a>
</div>
""")
with gr.Column():
with gr.Row():
with gr.Column():
# step1.2 - object motion control - draw yourself
gr.Markdown("---\n## Step 1/2: Draw A Trajectory", show_label=False, visible=True)
gr.Markdown(
"\n 1. **Click on the `Canvas` to create a trajectory.** Each click adds a new point to the trajectory. \
\n 2. Click on `Visualize Trajectory` to view the trajectory as a video; \
\n 3. Click on `Reset Trajectory` to clear the trajectory. (Currently, this demo does not support multi-trajectory control. To achieve multi-trajectory control, you can use the command line available on <a href='https://github.com/alibaba/Tora'>GitHub</a>.)",
show_label=False,
visible=True,
)
traj_args = gr.Textbox(value="", label="Points of Trajectory", visible=True)
traj_list = gr.State([])
with gr.Row():
traj_vis = gr.Button(value="Visualize Trajectory", visible=True)
traj_reset = gr.Button(value="Reset Trajectory", visible=True)
traj_droplast = gr.Button(value="Drop Last Point", visible=True)
with gr.Column():
traj_input = gr.Image(
grid_image,
width=canvas_width // 2,
height=canvas_height // 2,
label="Canvas for Drawing",
visible=True,
)
vis_traj = gr.Video(
value=None,
label="Trajectory",
visible=True,
width=canvas_width // 2,
height=canvas_height // 2,
)
# step2 - Add prompt and Generate videos
with gr.Row():
with gr.Column():
step3_prompt_generate = gr.Markdown(
"---\n## Step 2/2: Add prompt(Highly recommend using GPT-4o for refinement).",
show_label=False,
visible=True,
)
prompt = gr.Textbox(value="", label="Prompt", interactive=True, visible=True)
n_samples = gr.Number(value=1, precision=0, interactive=True, label="n_samples", visible=False)
seed = gr.Number(value=1234, precision=0, interactive=True, label="Seed", visible=True)
start = gr.Button(value="Generate", visible=True)
with gr.Column():
gen_video = gr.Video(value=None, label="Generated Video", visible=True)
# gen_video = gr.Gallery(value=None, label="Generate Video", visible=True)
# traj examples
with gr.Column():
gr.Markdown("---\n## Trajectory Examples", show_label=False, visible=True)
with gr.Row():
traj_1 = gr.Button(value="circle", visible=True)
traj_2 = gr.Button(value="spiral", visible=True)
traj_3 = gr.Button(value="coaster", visible=True)
traj_4 = gr.Button(value="dance", visible=True)
with gr.Row():
traj_5 = gr.Button(value="infinity", visible=True)
traj_6 = gr.Button(value="pause", visible=True)
traj_7 = gr.Button(value="shake", visible=True)
traj_8 = gr.Button(value="wave", visible=True)
# prompt examples
with gr.Column():
gr.Markdown("---\n## Prompt Examples", show_label=False, visible=True)
with gr.Row():
prompt_1 = gr.Button(value="rubber duck", visible=True)
prompt_2 = gr.Button(value="dandelion", visible=True)
prompt_3 = gr.Button(value="golden retriever", visible=True)
prompt_4 = gr.Button(value="squirrel", visible=True)
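# The example Buttons are wired as inputs so their label text is passed to the callbacks, which
# use it to look up the corresponding preset trajectory or prompt.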
traj_1.click(fn=add_provided_traj, inputs=[traj_list, traj_1], outputs=[traj_input, traj_args, traj_input])
traj_2.click(fn=add_provided_traj, inputs=[traj_list, traj_2], outputs=[traj_input, traj_args, traj_input])
traj_3.click(fn=add_provided_traj, inputs=[traj_list, traj_3], outputs=[traj_input, traj_args, traj_input])
traj_4.click(fn=add_provided_traj, inputs=[traj_list, traj_4], outputs=[traj_input, traj_args, traj_input])
traj_5.click(fn=add_provided_traj, inputs=[traj_list, traj_5], outputs=[traj_input, traj_args, traj_input])
traj_6.click(fn=add_provided_traj, inputs=[traj_list, traj_6], outputs=[traj_input, traj_args, traj_input])
traj_7.click(fn=add_provided_traj, inputs=[traj_list, traj_7], outputs=[traj_input, traj_args, traj_input])
traj_8.click(fn=add_provided_traj, inputs=[traj_list, traj_8], outputs=[traj_input, traj_args, traj_input])
prompt_1.click(fn=add_provided_prompt, inputs=prompt_1, outputs=prompt)
prompt_2.click(fn=add_provided_prompt, inputs=prompt_2, outputs=prompt)
prompt_3.click(fn=add_provided_prompt, inputs=prompt_3, outputs=prompt)
prompt_4.click(fn=add_provided_prompt, inputs=prompt_4, outputs=prompt)
traj_vis.click(
fn=fn_vis_traj,
inputs=traj_list,
outputs=[vis_traj],
)
traj_input.select(fn=add_traj_point, inputs=traj_list, outputs=[traj_input, traj_args])
traj_droplast.click(fn=fn_traj_droplast, inputs=traj_list, outputs=[traj_input, traj_args, traj_input])
traj_reset.click(fn=fn_traj_reset, inputs=traj_list, outputs=[traj_input, traj_args, traj_input])
# global traj_list
start.click(fn=model_run_v2, inputs=[prompt, seed, traj_list, n_samples], outputs=gen_video)
gr.Markdown(article)
demo.queue(max_size=32).launch(**args)
if __name__ == "__main__":
# python app.py --load {model_path}
parser = argparse.ArgumentParser()
parser.add_argument(
"--listen",
type=str,
default="0.0.0.0" if "SPACE_ID" in os.environ else "127.0.0.1",
help="IP to listen on for connections to Gradio",
)
parser.add_argument("--username", type=str, default="", help="Username for authentication")
parser.add_argument("--password", type=str, default="", help="Password for authentication")
parser.add_argument(
"--server_port",
type=int,
default=0,
help="Port to run the server listener on",
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
parser.add_argument("--share", action="store_true", help="Share the gradio UI")
parser.add_argument(
"--base", type=str, default="configs/tora/model/cogvideox_5b_tora.yaml configs/tora/inference_sparse.yaml"
)
parser.add_argument("--load", type=str, default="ckpts/tora/t2v/")
args = parser.parse_args()
launch_kwargs = {}
launch_kwargs["server_name"] = args.listen
if args.username and args.password:
launch_kwargs["auth"] = (args.username, args.password)
if args.server_port:
launch_kwargs["server_port"] = args.server_port
if args.inbrowser:
launch_kwargs["inbrowser"] = args.inbrowser
if args.share:
launch_kwargs["share"] = args.share
from arguments import get_args
# TODO: pass the 'base' configs through from the command line instead of hard-coding them here
tora_args_list = [
"--base",
"configs/tora/model/cogvideox_5b_tora.yaml",
"configs/tora/inference_sparse.yaml",
"--load",
args.load,
]
tora_args = get_args(tora_args_list)
tora_args = argparse.Namespace(**vars(tora_args))
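# Override the loaded config for single-process Gradio inference: no context/model parallelism,
# no activation checkpointing, and encode/decode one sample at a time.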
tora_args.model_config.first_stage_config.params.cp_size = 1
tora_args.model_config.network_config.params.transformer_args.model_parallel_size = 1
tora_args.model_config.network_config.params.transformer_args.checkpoint_activations = False
tora_args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
tora_args.model_config.en_and_decode_n_samples_a_time = 1
cold_start(args=tora_args)
print("******************** model loaded ********************")
threading.Thread(target=delete_old_files, args=("/tmp/Tora", 48, 3600 * 24), daemon=True).start()
main(launch_kwargs)