def model_run_v2()

in sat/app.py
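This excerpt assumes the module-level setup earlier in sat/app.py (not shown here): imports such as torch, math, gc, gradio as gr, einops.rearrange, and a flow_to_image helper (e.g. torchvision.utils.flow_to_image), project helpers such as set_random_seed, process_traj, scale_traj_list_to_256, get_batch, get_unique_embedder_keys_from_conditioner, process_points, and save_video_as_grid_and_mp4, plus the globals model, canvas_width, and canvas_height.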


def model_run_v2(prompt, seed, traj_list, n_samples=1):
    global model

    image_size = [480, 720]
    sampling_num_frames = 13  # Must be 13, 11 or 9
    latent_channels = 16
    sampling_fps = 8

    sample_func = model.sample
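    # F = 8 is the VAE's spatial downsampling factor: latents are H/8 x W/8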
    T, H, W, C, F = sampling_num_frames, image_size[0], image_size[1], latent_channels, 8
    num_samples = [1]
    force_uc_zero_embeddings = ["txt"]
    device = model.device

    global canvas_width, canvas_height
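    # keep two copies of the trajectory: one in the original canvas coordinates
    # (used when saving the video), and one rescaled to a 256x256 range for
    # process_traj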
    traj_list_range_video = traj_list.copy()
    traj_list_range_256 = scale_traj_list_to_256(traj_list, canvas_width, canvas_height)

    with torch.no_grad():
        set_random_seed(seed)
        total_num_frames = (T - 1) * 4 + 1  # T is the video latent length; (13 - 1) * 4 + 1 = 49 frames

        video_flow, points = process_traj(traj_list_range_256, total_num_frames, image_size, device=device)
        if video_flow is not None:
            video_flow = video_flow.unsqueeze_(0)  # add batch dim: [T H W C] -> [1 T H W C]
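            # Flow-conditioning pipeline: render the flow field as RGB frames,
            # then VAE-encode those frames into latents for the sampler.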
            model.to("cpu")  # move model to cpu, run vae on gpu only.
            tmp = rearrange(video_flow[0], "T H W C -> T C H W")
            video_flow = flow_to_image(tmp).unsqueeze_(0).to("cuda")  # [1 T C H W]

            del tmp
            video_flow = (
                rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(torch.bfloat16)
            )  # map [0, 255] -> [-1, 1] and move to channels-first [B C T H W]
            torch.cuda.empty_cache()
            video_flow = video_flow.repeat(2, 1, 1, 1, 1).contiguous()  # duplicate along batch for the unconditional (CFG) branch
            model.first_stage_model.to(device)
            video_flow = model.encode_first_stage(video_flow, None)  # VAE-encode the flow frames into latents
            video_flow = video_flow.permute(0, 2, 1, 3, 4).contiguous()  # [B C T H W] -> [B T C H W]
            model.to(device)

        value_dict = {
            "prompt": prompt,
            "negative_prompt": "",
            "num_frames": torch.tensor(T).unsqueeze(0),
        }

        batch, batch_uc = get_batch(
            get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
        )

        # build conditional (c) and unconditional (uc) embeddings for classifier-free guidance
        c, uc = model.conditioner.get_unconditional_conditioning(
            batch,
            batch_uc=batch_uc,
            force_uc_zero_embeddings=force_uc_zero_embeddings,
        )

        for k in c:
            if k != "crossattn":
                c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))

        for index in range(1):  # one sample per call; the n_samples argument is currently unused
            # reload model on GPU
            model.to(device)

            # run the diffusion sampler in latent space, conditioned on the
            # trajectory flow latents (if provided)
            samples_z = sample_func(
                c,
                uc=uc,
                batch_size=1,
                shape=(T, C, H // F, W // F),
                video_flow=video_flow,
            )
            samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()  # [B T C H W] -> [B C T H W] for the VAE decoder

            # Unload the model from GPU to save GPU memory
            model.to("cpu")
            torch.cuda.empty_cache()
            first_stage_model = model.first_stage_model.to(device)

            latent = 1.0 / model.scale_factor * samples_z  # undo the latent scaling before decoding

            # Decode the latents serially in small temporal chunks to save GPU
            # memory: the first chunk covers latent frames [0, 3), each later
            # chunk covers two more latent frames.
            recons = []
            loop_num = (T - 1) // 2
            for i in range(loop_num):
                if i == 0:
                    start_frame, end_frame = 0, 3
                else:
                    start_frame, end_frame = i * 2 + 1, i * 2 + 3
                clear_fake_cp_cache = i == loop_num - 1  # free the fake context-parallel cache on the last chunk
                with torch.no_grad():
                    recon = first_stage_model.decode(
                        latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
                    )

                recons.append(recon)

            recon = torch.cat(recons, dim=2).to(torch.float32)
            samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()  # [-1, 1] -> [0, 1]
            # [b, f, c, h, w]
            file_path_list = save_video_as_grid_and_mp4(
                samples,
                fps=sampling_fps,
                traj_points=process_points(traj_list_range_video),  # interpolate to 49 points
            )
            print(file_path_list)

        del samples_z, samples_x, samples, video_flow, latent, recon, recons, c, uc, batch, batch_uc
        gc.collect()
        torch.cuda.empty_cache()

        return gr.update(value=file_path_list[1], height=image_size[0], width=image_size[1])
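
A minimal sketch of how this function might be hooked into a Gradio UI. The component names and the trajectory state are hypothetical; only model_run_v2 itself comes from app.py:

import gradio as gr

with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Prompt")
    seed_box = gr.Number(label="Seed", value=42, precision=0)  # precision=0 -> integer seed
    traj_state = gr.State([])  # hypothetical: filled by a drawing canvas elsewhere in app.py
    run_btn = gr.Button("Generate")
    out_video = gr.Video(label="Result")

    # model_run_v2 returns a gr.update(...) targeting the video component
    run_btn.click(
        fn=model_run_v2,
        inputs=[prompt_box, seed_box, traj_state],
        outputs=out_video,
    )

demo.launch()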