in sat/app.py [0:0]
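# Context assumed from the rest of sat/app.py (not shown in this excerpt): imports of
# torch, math, gc, gradio as gr, einops.rearrange, a flow_to_image helper (e.g.
# torchvision.utils.flow_to_image), the app's own utilities (process_traj,
# scale_traj_list_to_256, get_batch, get_unique_embedder_keys_from_conditioner,
# process_points, save_video_as_grid_and_mp4, set_random_seed), and the loaded
# globals `model`, `canvas_width`, `canvas_height`.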
def model_run_v2(prompt, seed, traj_list, n_samples=1):
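    """Generate a trajectory-conditioned video from a text prompt.

    Scales the user-drawn trajectory, renders it as an optical-flow video and
    encodes it through the VAE, samples video latents conditioned on both the
    prompt and the flow, decodes the latents chunk by chunk, and returns a
    Gradio update pointing at the saved MP4.
    """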
    global model
    image_size = [480, 720]
    sampling_num_frames = 13  # must be 13, 11, or 9
    latent_channels = 16
    sampling_fps = 8
    sample_func = model.sample
    # T: latent frames; H, W: output size; C: latent channels; F: VAE spatial downsample factor
    T, H, W, C, F = sampling_num_frames, image_size[0], image_size[1], latent_channels, 8
    num_samples = [1]
    force_uc_zero_embeddings = ["txt"]
    device = model.device
    global canvas_width, canvas_height
    traj_list_range_video = traj_list.copy()
    # Rescale the drawn trajectory from canvas coordinates to the 256x256 reference space used by process_traj
    traj_list_range_256 = scale_traj_list_to_256(traj_list, canvas_width, canvas_height)
    with torch.no_grad():
        set_random_seed(seed)
        total_num_frames = (T - 1) * 4 + 1  # T latent frames decode to (13 - 1) * 4 + 1 = 49 pixel frames
        video_flow, points = process_traj(traj_list_range_256, total_num_frames, image_size, device=device)
        if video_flow is not None:
            video_flow = video_flow.unsqueeze_(0)  # add a batch dimension: [T H W C] -> [1 T H W C]
            model.to("cpu")  # move the diffusion model to CPU; only the VAE runs on GPU here
            tmp = rearrange(video_flow[0], "T H W C -> T C H W")
            video_flow = flow_to_image(tmp).unsqueeze_(0).to("cuda")  # render flow as RGB frames: [1 T C H W]
            del tmp
            # Map uint8 RGB in [0, 255] to [-1, 1] and reorder to [B C T H W] for the VAE
            video_flow = (
                rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(torch.bfloat16)
            )
            torch.cuda.empty_cache()
            video_flow = video_flow.repeat(2, 1, 1, 1, 1).contiguous()  # duplicate for the unconditional branch
            model.first_stage_model.to(device)
            video_flow = model.encode_first_stage(video_flow, None)
            video_flow = video_flow.permute(0, 2, 1, 3, 4).contiguous()  # [B C T H W] -> [B T C H W]
            model.to(device)  # move the diffusion model back to GPU
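        # Build text conditioning for classifier-free guidance: the negative prompt is left
        # empty, and the "txt" embedding is zeroed in the unconditional branch via
        # force_uc_zero_embeddings.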
        value_dict = {
            "prompt": prompt,
            "negative_prompt": "",
            "num_frames": torch.tensor(T).unsqueeze(0),
        }
        batch, batch_uc = get_batch(
            get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
        )
        c, uc = model.conditioner.get_unconditional_conditioning(
            batch,
            batch_uc=batch_uc,
            force_uc_zero_embeddings=force_uc_zero_embeddings,
        )
        # Move every non-"crossattn" conditioning tensor to the GPU, truncated to the sample count
        for k in c:
            if k != "crossattn":
                c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
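        # Sample video latents of shape (T, C, H // F, W // F) = (13, 16, 60, 90)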
        for index in range(1):  # generates a single sample; the n_samples argument is currently unused
            # reload the model on GPU
            model.to(device)
            samples_z = sample_func(
                c,
                uc=uc,
                batch_size=1,
                shape=(T, C, H // F, W // F),
                video_flow=video_flow,
            )
            samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()  # [B T C H W] -> [B C T H W] so time is dim 2
            # Unload the diffusion model from GPU to free memory for the VAE decoder
            model.to("cpu")
            torch.cuda.empty_cache()
            first_stage_model = model.first_stage_model.to(device)
            latent = 1.0 / model.scale_factor * samples_z  # undo the latent scaling applied at encode time
            # Decode the latent serially, chunk by chunk, to save GPU memory
            recons = []
            loop_num = (T - 1) // 2  # 6 chunks for T = 13
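            # Chunk layout: chunk 0 decodes latent frames [0, 3); each later chunk i decodes
            # [2i + 1, 2i + 3), so the 6 chunks cover latents 0..12 exactly once and together
            # yield all (13 - 1) * 4 + 1 = 49 pixel frames.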
            for i in range(loop_num):
                if i == 0:
                    start_frame, end_frame = 0, 3
                else:
                    start_frame, end_frame = i * 2 + 1, i * 2 + 3
                # Release the decoder's cached (fake context-parallel) state on the last chunk only
                clear_fake_cp_cache = i == loop_num - 1
                with torch.no_grad():
                    recon = first_stage_model.decode(
                        latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
                    )
                recons.append(recon)
            recon = torch.cat(recons, dim=2).to(torch.float32)  # concatenate the chunks along time
            samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()  # [b, c, f, h, w] -> [b, f, c, h, w]
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()  # [-1, 1] -> [0, 1]
            file_path_list = save_video_as_grid_and_mp4(
                samples,
                fps=sampling_fps,
                traj_points=process_points(traj_list_range_video),  # interpolate the trajectory to 49 points
            )
            print(file_path_list)
    # Free everything and hand the saved video back to the Gradio player
    del samples_z, samples_x, samples, video_flow, latent, recon, recons, c, uc, batch, batch_uc
    gc.collect()
    torch.cuda.empty_cache()
    return gr.update(value=file_path_list[1], height=image_size[0], width=image_size[1])