in threestudio/scripts/train_dreambooth.py [0:0]
def collate_fn(examples, with_prior_preservation=False):
    """Merge a list of per-sample dicts into a single batch dict.

    Args:
        examples: Sequence of dicts. Each must provide
            ``"instance_prompt_ids"`` and ``"instance_images"``. The optional
            keys ``"instance_attention_mask"`` and ``"instance_camera_pose"``
            are detected on the first sample only and, when found there, are
            read from every sample.
        with_prior_preservation: When true, the ``"class_prompt_ids"`` and
            ``"class_images"`` entries (plus ``"class_attention_mask"`` if
            attention masks are in use) of every sample are appended after
            the instance entries.

    Returns:
        A dict with ``"input_ids"`` (concatenated along dim 0) and
        ``"pixel_values"`` (stacked, contiguous, float32), plus
        ``"attention_mask"`` and/or ``"camera_pose"`` (each concatenated
        along dim 0) when the corresponding instance keys are present.
    """
    head = examples[0]
    use_mask = "instance_attention_mask" in head
    use_pose = "instance_camera_pose" in head

    def gather(key):
        # One entry per sample, in the order the samples were given.
        return [sample[key] for sample in examples]

    token_ids = gather("instance_prompt_ids")
    images = gather("instance_images")
    masks = gather("instance_attention_mask") if use_mask else None
    poses = gather("instance_camera_pose") if use_pose else None

    # With prior preservation the class samples ride along in the same batch
    # as the instance samples, so a single forward pass covers both.
    if with_prior_preservation:
        token_ids.extend(gather("class_prompt_ids"))
        images.extend(gather("class_images"))
        if use_mask:
            masks.extend(gather("class_attention_mask"))
        # NOTE(review): camera poses are not extended here, so "camera_pose"
        # keeps one row per instance sample while the other entries double
        # in length — confirm the consumer expects that.

    # Images are stacked (a new leading dim is created); ids/masks/poses are
    # concatenated, which presumably relies on each already carrying a
    # leading dim — TODO confirm against the dataset.
    stacked_images = (
        torch.stack(images).to(memory_format=torch.contiguous_format).float()
    )
    batch = {
        "input_ids": torch.cat(token_ids, dim=0),
        "pixel_values": stacked_images,
    }

    if use_mask:
        batch["attention_mask"] = torch.cat(masks, dim=0)
    if use_pose:
        batch["camera_pose"] = torch.cat(poses, dim=0)

    return batch