sat/sample_video.py (347 lines of code) (raw):

import argparse
import gc
import json
import math
import os
import pickle
from pathlib import Path
from typing import List, Union

import cv2
import imageio
import numpy as np
import torch
import torchvision.transforms as TT
from arguments import get_args
from diffusion_video import SATVideoDiffusionEngine
from einops import rearrange, repeat
from omegaconf import ListConfig
from PIL import Image
from torchvision.io import write_video
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
from torchvision.utils import flow_to_image
from tqdm import tqdm
from utils.flow_utils import process_traj
from utils.misc import vis_tensor

from sat import mpu
from sat.arguments import set_random_seed
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint


def read_from_cli():
    """Yield (prompt, index) pairs typed interactively until EOF (Ctrl-D)."""
    cnt = 0
    try:
        while True:
            x = input("Please input English text (Ctrl-D quit): ")
            yield x.strip(), cnt
            cnt += 1
    except EOFError:
        pass


def read_from_file(p, rank=0, world_size=1):
    """Yield (stripped line, line index) for the lines of ``p`` assigned to this rank.

    Lines are sharded round-robin across data-parallel ranks: line ``i`` is handled by
    rank ``i % world_size``.
    """
    with open(p, "r") as fin:
        for cnt, line in enumerate(fin):
            if cnt % world_size == rank:
                yield line.strip(), cnt


def get_unique_embedder_keys_from_conditioner(conditioner):
    """Return the distinct ``input_key`` values of the conditioner's embedders."""
    return list({x.input_key for x in conditioner.embedders})


def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
    """Build the conditional / unconditional input dicts for the conditioner.

    Args:
        keys: input keys of the conditioner's embedders.
        value_dict: must contain "prompt" and "negative_prompt" when "txt" is among ``keys``;
            every other key is copied into the conditional batch unchanged.
        N: batch shape; the prompt strings are tiled to this shape as nested lists.
        T: if given, stored as ``batch["num_video_frames"]``.
        device: unused; kept for interface compatibility.

    Returns:
        (batch, batch_uc). Tensor entries of ``batch`` without an unconditional
        counterpart are cloned into ``batch_uc``.
    """
    batch = {}
    batch_uc = {}
    num_items = math.prod(N)
    for key in keys:
        if key == "txt":
            batch["txt"] = np.repeat([value_dict["prompt"]], repeats=num_items).reshape(N).tolist()
            batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=num_items).reshape(N).tolist()
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    for key, value in batch.items():
        if key not in batch_uc and isinstance(value, torch.Tensor):
            batch_uc[key] = torch.clone(value)
    return batch, batch_uc


def draw_points(video, points):
    """Overlay trajectory polylines and per-frame point markers onto video frames.

    Parameters:
        video (torch.Tensor): video of shape [T, H, W, C] (C == 3).
        points: N trajectories indexable as ``points[n][t] -> (x, y)`` for every frame t;
            coordinates must be acceptable to OpenCV drawing calls (integers).

    Returns:
        torch.Tensor: the annotated video, same shape, device and dtype as ``video``.
    """
    T = video.shape[0]
    N = len(points)
    device = video.device
    dtype = video.dtype
    video = video.cpu().numpy().copy()

    # Rasterise every trajectory once onto a shared [H, W, C] canvas. The colour's last
    # channel is 1 (not 0) so that `traj[..., -1] > 0` below marks every line pixel.
    traj = np.zeros(video.shape[-3:], dtype=np.uint8)
    for n in range(N):
        for t in range(1, T):
            cv2.line(traj, tuple(points[n][t - 1]), tuple(points[n][t]), (255, 1, 1), 2)

    # The canvas is identical for every frame, so the blend mask is computed once.
    mask = repeat(traj[..., -1] > 0, "h w -> h w c", c=3)
    alpha = 0.7
    for t in range(T):
        video[t][mask] = video[t][mask] * (1 - alpha) + traj[mask] * alpha
        for n in range(N):
            cv2.circle(video[t], tuple(points[n][t]), 3, (160, 230, 100), -1)

    return torch.from_numpy(video).to(device, dtype)


def _write_mp4(path, frames, fps):
    """Encode a uint8 [T, H, W, C] tensor to an H.264 mp4 at CRF 18."""
    write_video(str(path), frames, fps=fps, video_codec="libx264", options={"crf": "18"})


def save_video_as_grid_and_mp4(
    video_batch: torch.Tensor,
    save_path: str,
    name: str,
    fps: int = 5,
    args=None,
    key=None,
    traj_points=None,
    prompt="",
):
    """Save each video of a batch as mp4 plus its prompt (and trajectory artefacts).

    For every video ``i`` in ``video_batch`` (each [T, C, H, W], values in [0, 1]) this writes
    ``<save_path>/video/<name>_<i>.mp4`` and ``<save_path>/prompt/<name>_<i>.txt``. When
    ``traj_points`` is given it also pickles the points to ``traj/`` and writes a copy with
    the trajectories drawn on it to ``traj_video/``. No grid image is produced despite the
    name; ``args`` and ``key`` are unused and kept for interface compatibility.
    """
    p = Path(save_path)
    os.makedirs(p / "video", exist_ok=True)
    os.makedirs(p / "prompt", exist_ok=True)
    if traj_points is not None:
        os.makedirs(p / "traj", exist_ok=True)
        os.makedirs(p / "traj_video", exist_ok=True)

    for i, vid in enumerate(video_batch):
        stem = f"{name}_{i:06d}"
        x = rearrange(vid, "t c h w -> t h w c")
        x = x.mul(255).add(0.5).clamp(0, 255).to("cpu", torch.uint8)  # [T H W C]
        _write_mp4(p / "video" / f"{stem}.mp4", x, fps)

        if traj_points is not None:
            with open(p / "traj" / f"{stem}.pkl", "wb") as f:
                pickle.dump(traj_points, f)
            _write_mp4(p / "traj_video" / f"{stem}.mp4", draw_points(x, traj_points), fps)

        with open(p / "prompt" / f"{stem}.txt", "w", encoding="utf-8") as f:
            f.write(prompt)


def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
    """Resize ``arr`` to cover ``image_size`` ([height, width]) and crop to exactly that size.

    ``arr`` is treated as height on dim 2 and width on dim 3, and dim 0 is squeezed after
    resizing; presumably a [1, C, H, W] tensor (TODO confirm with callers).
    ``reshape_mode`` "random"/"none" picks a random crop offset, "center" a centred crop.
    """
    if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
        arr = resize(
            arr,
            size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
            interpolation=InterpolationMode.BICUBIC,
        )
    else:
        arr = resize(
            arr,
            size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
            interpolation=InterpolationMode.BICUBIC,
        )

    h, w = arr.shape[2], arr.shape[3]
    arr = arr.squeeze(0)

    delta_h = h - image_size[0]
    delta_w = w - image_size[1]

    if reshape_mode in ("random", "none"):
        top = np.random.randint(0, delta_h + 1)
        left = np.random.randint(0, delta_w + 1)
    elif reshape_mode == "center":
        top, left = delta_h // 2, delta_w // 2
    else:
        raise NotImplementedError
    return TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])


def _split_prompt_line(line, args):
    """Split one input line into (prompt, image_path, flow_file) according to the CLI flags.

    Formats: image2video + flow_from_prompt -> "prompt@@@image@@@flow"; image2video only ->
    "prompt@@image"; flow_from_prompt only -> "prompt<TAB>flow"; otherwise the whole line is
    the prompt. Parts that the format does not carry are returned as None.
    """
    image_path = None
    flow_files = None
    if args.image2video:
        if args.flow_from_prompt:
            text, image_path, flow_files = line.strip().split("@@@")
            print(flow_files, image_path)
        else:
            text, image_path = line.split("@@")
    elif args.flow_from_prompt:
        text, flow_files = line.split("\t")
    else:
        text = line
    return text, image_path, flow_files


def _encode_first_frame(model, image_path, image_size, transform, num_latent_frames, latent_channels, downsample):
    """VAE-encode the conditioning image and append ``num_latent_frames - 1`` zero latent frames.

    The padding is concatenated along dim 1 with shape
    (batch, num_latent_frames - 1, latent_channels, H // downsample, W // downsample).

    Raises:
        FileNotFoundError: if ``image_path`` does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(image_path)
    with Image.open(image_path) as img:  # close the file handle once the pixels are loaded
        image = img.convert("RGB")
    height, width = image_size
    image = image.resize((width, height))  # PIL takes (width, height)
    image = transform(image).unsqueeze(0).to("cuda")
    image = image * 2.0 - 1.0  # ToTensor yields [0, 1]; rescale to [-1, 1]
    image = image.unsqueeze(2).to(torch.bfloat16)  # [1, C, 1, H, W]: insert a time axis
    image = model.encode_first_stage(image, None)
    image = image.permute(0, 2, 1, 3, 4).contiguous()
    pad_shape = (image.shape[0], num_latent_frames - 1, latent_channels, height // downsample, width // downsample)
    padding = torch.zeros(pad_shape).to(image.device).to(image.dtype)
    return torch.concat([image, padding], dim=1)


def _load_raw_flow(args, flow_files, total_num_frames, image_size, device):
    """Load the raw (un-encoded) motion condition selected by the CLI flags.

    Returns:
        (video_flow, points): ``video_flow`` truncated to ``total_num_frames`` frames with a
        leading batch axis added, or None when no flow is injected; ``points`` is the
        trajectory returned by ``process_traj`` in point mode, else None.

    Raises:
        ValueError: if ``flow_from_prompt`` is set without ``flow_path``.
    """
    if args.no_flow_injection:
        return None, None

    if args.flow_from_prompt:
        if args.flow_path is None:
            raise ValueError("Flow path must be provided if flow_from_prompt is True")
        path = os.path.join(args.flow_path, flow_files)
        print(f"Flow path: {path}")
        flow = torch.load(path, map_location="cpu", weights_only=True)[:total_num_frames].unsqueeze_(0).cuda()
        return flow, None

    if args.flow_path:
        print(f"Flow path: {args.flow_path}")
        flow = torch.load(args.flow_path, map_location=device, weights_only=True)[:total_num_frames].unsqueeze_(0)
        return flow, None

    if args.point_path:
        if isinstance(args.point_path, str):
            # Parsed once and cached back on args so later prompts reuse the decoded value.
            args.point_path = json.loads(args.point_path)
        print(f"Point path: {args.point_path}")
        flow, points = process_traj(args.point_path, total_num_frames, image_size, device=device)
        return flow.unsqueeze_(0), points

    print("No flow injection")
    return None, None


def _prepare_video_flow(model, args, flow_files, total_num_frames, image_size, device):
    """Load the motion condition and VAE-encode it for the sampler.

    The raw flow is indexed as ``[0]`` then rearranged "T H W C -> T C H W", so it is
    presumably [1, frames, H, W, 2] (TODO confirm against process_traj / saved flow files).

    Returns:
        (video_flow, points): the encoded flow with its batch duplicated for the
        unconditional branch, or None when no flow is injected; plus the trajectory points
        (point mode only, else None).
    """
    video_flow, points = _load_raw_flow(args, flow_files, total_num_frames, image_size, device)
    if video_flow is None:
        return None, points

    model.to("cpu")  # move model to cpu, run vae on gpu only.
    tmp = rearrange(video_flow[0], "T H W C -> T C H W")
    video_flow = flow_to_image(tmp).unsqueeze_(0).to("cuda")  # [1 T C H W]
    if args.vis_traj_features:
        os.makedirs("samples/flow", exist_ok=True)
        vis_tensor(tmp, *tmp.shape[-2:], "samples/flow/flow1_vis.gif")
        imageio.mimwrite(
            "samples/flow/flow2_vis.gif",
            rearrange(video_flow[0], "T C H W -> T H W C").cpu(),
            fps=8,
            loop=0,
        )
    del tmp  # drop the last reference to the raw flow before running the VAE

    # flow_to_image produces uint8 RGB images; rescale to [-1, 1] for the VAE.
    video_flow = rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(torch.bfloat16)
    video_flow = video_flow.repeat(2, 1, 1, 1, 1).contiguous()  # for uncondition
    model.first_stage_model.to(device)
    torch.cuda.empty_cache()
    video_flow = model.encode_first_stage(video_flow, None)
    video_flow = video_flow.permute(0, 2, 1, 3, 4).contiguous()
    torch.cuda.empty_cache()
    model.to(device)
    return video_flow, points


def _decode_latent_in_chunks(first_stage_model, latent, num_latent_frames):
    """Decode ``latent`` (time on dim 2) in small temporal windows to bound GPU memory.

    The first window covers latent frames [0, 3), each later window the next two frames;
    only the last call passes ``clear_fake_cp_cache=True``. Windows are concatenated on
    dim 2 and cast to float32. Note: with an even ``num_latent_frames`` the final latent
    frame is not decoded, and fewer than 3 frames yields no windows (torch.cat raises).
    """
    recons = []
    loop_num = (num_latent_frames - 1) // 2
    for i in range(loop_num):
        if i == 0:
            start_frame, end_frame = 0, 3
        else:
            start_frame, end_frame = i * 2 + 1, i * 2 + 3
        clear_fake_cp_cache = i == loop_num - 1
        with torch.no_grad():
            recon = first_stage_model.decode(
                latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
            )
        recons.append(recon)
    return torch.cat(recons, dim=2).to(torch.float32)


def sampling_main(args, model_cls):
    """Generate and save videos for every prompt from the configured input source.

    For each prompt: optionally encode a first frame (``--image2video``), optionally load
    and VAE-encode a motion condition (flow file or point trajectories), build the text
    conditioning, then draw ``args.num_samples_per_prompt`` samples and save them under
    ``args.output_dir`` (model-parallel rank 0 only).

    Args:
        args: parsed SAT/CLI namespace (note: ``args.seed`` and ``args.point_path`` are
            updated in place).
        model_cls: a model class (instantiated via ``get_model``) or an already-built model.
    """
    if isinstance(model_cls, type):
        model = get_model(args, model_cls)
    else:
        model = model_cls

    load_checkpoint(model, args)
    model.eval()

    if args.input_type == "cli":
        data_iter = read_from_cli()
    elif args.input_type == "txt":
        rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
        print("rank and world_size", rank, world_size)
        data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
    else:
        raise NotImplementedError

    image_size = [480, 720]  # [height, width] of the generated frames
    transform = TT.Compose([TT.ToTensor()]) if args.image2video else None

    sample_func = model.sample
    T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
    num_samples = [1]
    force_uc_zero_embeddings = ["txt"]
    device = model.device
    # Frame count matching T latent frames: (T - 1) * 4 + 1, e.g. 13 latent frames -> 49 frames.
    total_num_frames = (T - 1) * 4 + 1

    with torch.no_grad():
        for line, cnt in tqdm(data_iter):
            print(line)
            set_random_seed(args.seed)
            text, image_path, flow_files = _split_prompt_line(line, args)

            image = None
            if args.image2video:
                image = _encode_first_frame(
                    model, os.path.join(args.img_dir, image_path), image_size, transform, T, C, F
                )

            video_flow, points = _prepare_video_flow(model, args, flow_files, total_num_frames, image_size, device)

            value_dict = {
                "prompt": text,
                "negative_prompt": "",
                "num_frames": torch.tensor(T).unsqueeze(0),
            }
            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
            )
            for key, value in batch.items():
                if isinstance(value, torch.Tensor):
                    print(key, value.shape)
                elif isinstance(value, list):
                    print(key, [len(item) for item in value])
                else:
                    print(key, value)

            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )
            for k in c:
                if not k == "crossattn":
                    c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))

            if args.image2video and image is not None:
                c["concat"] = image
                uc["concat"] = image

            for index in range(args.num_samples_per_prompt):
                # After the first prompt every sample draws a fresh seed (recorded in the file
                # name). args.seed is overwritten, so the next prompt's set_random_seed starts
                # from the last drawn seed.
                if cnt > 0:
                    args.seed = np.random.randint(1_000_000)
                    set_random_seed(args.seed)

                model.to(device)  # reload model on GPU
                samples_z = sample_func(
                    c,
                    uc=uc,
                    batch_size=1,
                    shape=(T, C, H // F, W // F),
                    video_flow=video_flow,
                )
                samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()

                # Unload the model from GPU to save GPU memory; only the VAE is needed to decode.
                model.to("cpu")
                torch.cuda.empty_cache()
                first_stage_model = model.first_stage_model.to(device)

                latent = 1.0 / model.scale_factor * samples_z
                recon = _decode_latent_in_chunks(first_stage_model, latent, T)
                samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()

                if args.flow_from_prompt:
                    name = Path(flow_files).stem
                    if args.num_samples_per_prompt > 1:
                        # Otherwise every sample of this prompt would overwrite the same file.
                        name += f"_{index}"
                else:
                    name = (
                        str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:60] + f"_{index}_seed{args.seed}"
                    )

                if mpu.get_model_parallel_rank() == 0:
                    save_video_as_grid_and_mp4(
                        samples,
                        args.output_dir,
                        name,
                        fps=args.sampling_fps,
                        traj_points=points,
                        prompt=text,
                    )

                # Free per-sample tensors only: c, uc and video_flow are reused by the next sample.
                del samples_z, samples_x, samples, latent, recon
                gc.collect()

            # Prompt-level conditioning is no longer needed once all samples are drawn.
            del video_flow, c, uc, batch, batch_uc


if __name__ == "__main__":
    if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
        os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
        os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
        os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
    py_parser = argparse.ArgumentParser(add_help=False)
    known, args_list = py_parser.parse_known_args()

    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))
    del args.deepspeed_config
    args.model_config.first_stage_config.params.cp_size = 1
    args.model_config.network_config.params.transformer_args.model_parallel_size = 1
    args.model_config.network_config.params.transformer_args.checkpoint_activations = False
    args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
    args.model_config.en_and_decode_n_samples_a_time = 1

    sampling_main(args, model_cls=SATVideoDiffusionEngine)