# sat/sample_video.py
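# NOTE (editorial): the imports below are an assumption reconstructed from the symbols
# this excerpt uses; the file's real import block may differ. Repo-local helpers
# (read_from_cli, read_from_file, get_batch, get_unique_embedder_keys_from_conditioner,
# save_video_as_grid_and_mp4, process_traj, flow_to_image, vis_tensor) and the sat
# utilities (get_model, load_checkpoint, mpu, set_random_seed) are imported elsewhere
# in the repo and are not reproduced here.
import gc
import json
import math
import os
from pathlib import Path

import imageio
import numpy as np
import torch
import torchvision.transforms as TT
from einops import rearrange
from PIL import Image
from tqdm import tqdm
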
def sampling_main(args, model_cls):
    if isinstance(model_cls, type):
        model = get_model(args, model_cls)
    else:
        model = model_cls
    load_checkpoint(model, args)
    model.eval()

    if args.input_type == "cli":
        data_iter = read_from_cli()
    elif args.input_type == "txt":
        rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
        print("rank and world_size", rank, world_size)
        data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
    else:
        raise NotImplementedError
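
    # read_from_file shards the prompt list by data-parallel rank, so each rank
    # generates videos for a disjoint subset of prompts.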
    image_size = [480, 720]

    if args.image2video:
        chained_transforms = []
        chained_transforms.append(TT.ToTensor())
        transform = TT.Compose(chained_transforms)

    sample_func = model.sample
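    # Shape constants: T = number of latent frames, H/W = pixel height/width,
    # C = latent channels, F = the VAE's spatial downsampling factor (8), so the
    # latent grid is (H // F) x (W // F).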
    T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
    num_samples = [1]
    force_uc_zero_embeddings = ["txt"]
    device = model.device
    with torch.no_grad():
        for text, cnt in tqdm(data_iter):
            print(text)
            set_random_seed(args.seed)
            if args.image2video:
                if args.flow_from_prompt:
                    text, image_path, flow_files = text.strip().split("@@@")
                    print(flow_files, image_path)
                else:
                    text, image_path = text.split("@@")
                image_path = os.path.join(args.img_dir, image_path)
                assert os.path.exists(image_path), image_path
                image = Image.open(image_path).convert("RGB")
                image = image.resize(tuple(reversed(image_size)))  # PIL expects (W, H)
                image = transform(image).unsqueeze(0).to("cuda")
                image = image * 2.0 - 1.0  # [0, 1] -> [-1, 1]
                image = image.unsqueeze(2).to(torch.bfloat16)  # add a singleton time dimension
                image = model.encode_first_stage(image, None)
                image = image.permute(0, 2, 1, 3, 4).contiguous()
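                # Pad the single-frame latent with zeros out to T latent frames; this
                # padded tensor later becomes the "concat" conditioning for image2video.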
                pad_shape = (image.shape[0], T - 1, C, H // F, W // F)
                image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1)
            else:
                image = None
                if args.flow_from_prompt:
                    text, flow_files = text.split("\t")
            total_num_frames = (T - 1) * 4 + 1  # T latent frames decode to (T - 1) * 4 + 1 pixel frames, e.g. 13 -> 49
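            # Flow-source precedence: explicit no-injection flag, then per-prompt flow
            # files, then a single flow tensor from --flow_path, then trajectories from
            # --point_path; otherwise no flow conditioning.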
            if args.no_flow_injection:
                video_flow = None
            elif args.flow_from_prompt:
                assert args.flow_path is not None, "Flow path must be provided if flow_from_prompt is True"
                p = os.path.join(args.flow_path, flow_files)
                print(f"Flow path: {p}")
                video_flow = (
                    torch.load(p, map_location="cpu", weights_only=True)[:total_num_frames].unsqueeze_(0).cuda()
                )
            elif args.flow_path:
                print(f"Flow path: {args.flow_path}")
                video_flow = torch.load(args.flow_path, map_location=device, weights_only=True)[
                    :total_num_frames
                ].unsqueeze_(0)
            elif args.point_path:
                if isinstance(args.point_path, str):
                    args.point_path = json.loads(args.point_path)
                print(f"Point path: {args.point_path}")
                video_flow, points = process_traj(args.point_path, total_num_frames, image_size, device=device)
                video_flow = video_flow.unsqueeze_(0)
            else:
                print("No flow injection")
                video_flow = None
            if video_flow is not None:
                model.to("cpu")  # keep only the VAE on the GPU while preprocessing the flow
                tmp = rearrange(video_flow[0], "T H W C -> T C H W")
                video_flow = flow_to_image(tmp).unsqueeze_(0).to("cuda")  # [1, T, C, H, W]
                if args.vis_traj_features:
                    os.makedirs("samples/flow", exist_ok=True)
                    vis_tensor(tmp, *tmp.shape[-2:], "samples/flow/flow1_vis.gif")
                    imageio.mimwrite(
                        "samples/flow/flow2_vis.gif",
                        rearrange(video_flow[0], "T C H W -> T H W C").cpu(),
                        fps=8,
                        loop=0,
                    )
                del tmp
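                # Rescale the flow visualization from [0, 255] to [-1, 1], then duplicate
                # it along the batch dimension so the conditional and unconditional
                # branches of classifier-free guidance share the same flow.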
                video_flow = (
                    rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(torch.bfloat16)
                )
                video_flow = video_flow.repeat(2, 1, 1, 1, 1).contiguous()  # batch of 2: cond + uncond
                model.first_stage_model.to(device)
                torch.cuda.empty_cache()
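                # The VAE encodes/decodes in [B, C, T, H, W]; the sampler works with
                # [B, T, C, H, W], hence the permute after encoding.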
                video_flow = model.encode_first_stage(video_flow, None)
                video_flow = video_flow.permute(0, 2, 1, 3, 4).contiguous()
                torch.cuda.empty_cache()
                model.to(device)
            value_dict = {
                "prompt": text,
                "negative_prompt": "",
                "num_frames": torch.tensor(T).unsqueeze(0),
            }

            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
            )
            for key in batch:  # debug: print the shape/contents of each conditioning entry
                if isinstance(batch[key], torch.Tensor):
                    print(key, batch[key].shape)
                elif isinstance(batch[key], list):
                    print(key, [len(l) for l in batch[key]])
                else:
                    print(key, batch[key])
            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )

            for k in c:
                if k != "crossattn":
                    c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
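            # For image2video, attach the zero-padded first-frame latent under the
            # "concat" key; in the sgm conditioner convention this is typically
            # concatenated with the noised latent along the channel dimension.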
            if args.image2video and image is not None:
                c["concat"] = image
                uc["concat"] = image
            for index in range(args.num_samples_per_prompt):
                if cnt > 0:
                    args.seed = np.random.randint(1e6)
                    set_random_seed(args.seed)
                # reload the model onto the GPU for sampling
                model.to(device)
                samples_z = sample_func(
                    c,
                    uc=uc,
                    batch_size=1,
                    shape=(T, C, H // F, W // F),
                    video_flow=video_flow,
                )
                samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()

                # Unload the diffusion model and keep only the VAE on the GPU to save memory
                model.to("cpu")
                torch.cuda.empty_cache()
                first_stage_model = model.first_stage_model.to(device)
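                # Undo the latent scaling (LDM convention: latents are multiplied by
                # scale_factor at encode time) before handing them to the VAE decoder.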
                latent = 1.0 / model.scale_factor * samples_z
                # Decode the latent in small temporal chunks to save GPU memory
                recons = []
                loop_num = (T - 1) // 2
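                # Chunk schedule: the first chunk decodes latent frames [0, 3), every
                # later chunk i decodes [2i + 1, 2i + 3); for T = 13 this covers all 13
                # frames in 6 chunks. The fake-CP cache is cleared only on the last chunk.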
                for i in range(loop_num):
                    if i == 0:
                        start_frame, end_frame = 0, 3
                    else:
                        start_frame, end_frame = i * 2 + 1, i * 2 + 3
                    clear_fake_cp_cache = i == loop_num - 1
                    with torch.no_grad():
                        recon = first_stage_model.decode(
                            latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
                        )
                    recons.append(recon)

                recon = torch.cat(recons, dim=2).to(torch.float32)
                samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()  # [-1, 1] -> [0, 1]

                save_path = args.output_dir
                name = str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:60] + f"_{index}_seed{args.seed}"
                if args.flow_from_prompt:
                    name = Path(flow_files).stem
                if mpu.get_model_parallel_rank() == 0:
                    save_video_as_grid_and_mp4(
                        samples,
                        save_path,
                        name,
                        fps=args.sampling_fps,
                        traj_points=locals().get("points", None),  # points exists only on the point_path branch
                        prompt=text,
                    )

            del samples_z, samples_x, samples, video_flow, latent, recon, recons, c, uc, batch, batch_uc
            gc.collect()
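

# Editorial sketch (not part of the original file): a standalone sanity check for the
# chunked-decode schedule used above. All names here are illustrative.
def _decode_chunk_schedule(T: int) -> list[tuple[int, int]]:
    """Return the (start_frame, end_frame) pairs the serial decode loop visits."""
    chunks = []
    for i in range((T - 1) // 2):
        start, end = (0, 3) if i == 0 else (i * 2 + 1, i * 2 + 3)
        chunks.append((start, end))
    return chunks


# Example: _decode_chunk_schedule(13) == [(0, 3), (3, 5), (5, 7), (7, 9), (9, 11), (11, 13)],
# i.e. all 13 latent frames are decoded exactly once across 6 chunks.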