in threestudio/scripts/train_dreambooth.py [0:0]
def __getitem__(self, index):
    """Build one training example for DreamBooth fine-tuning.

    Returns a dict with the transformed instance image and its prompt
    token ids (optionally view-dependent, optionally with a camera pose),
    plus a prior-preservation class image/prompt when ``class_data_root``
    is configured.
    """
    example = {}

    # Both image lists are cycled with modulo so the dataset can be
    # iterated for more steps than either list's length.
    instance_path = self.instance_images_path[index % self.num_instance_images]
    instance_image = Image.open(instance_path)
    instance_image = exif_transpose(instance_image)

    if self.class_labels_conditioning == "camera_pose":
        # A camera-pose .npy file is expected to sit next to each image.
        # Swap only the extension: the previous str.replace("png", "npy")
        # would corrupt any path containing "png" in a directory or stem.
        pose_path = os.path.splitext(str(instance_path))[0] + ".npy"
        instance_camera_pose = np.load(pose_path)
        example["instance_camera_pose"] = torch.tensor(instance_camera_pose).reshape(1, -1)

    view = None
    if self.use_view_dependent_prompt:
        # Filename is assumed to be "<4-char prefix><azimuth>.<3-char ext>",
        # e.g. "rgb_270.png" -> 270.0 — TODO confirm against the image dumper.
        angle = float(os.path.basename(str(instance_path))[4:-4])
        if angle < 45 or angle >= 315:
            view = "front view"
        elif 45 <= angle < 135 or 225 <= angle < 315:
            view = "side view"
        else:
            view = "back view"

    if instance_image.mode != "RGB":
        instance_image = instance_image.convert("RGB")
    example["instance_images"] = self.image_transforms(instance_image)

    if self.encoder_hidden_states is not None:
        # Pre-computed text embeddings take precedence over tokenization.
        example["instance_prompt_ids"] = self.encoder_hidden_states
    else:
        # Optionally append the view descriptor to the base prompt.
        if self.use_view_dependent_prompt:
            instance_prompt = self.instance_prompt + f", {view}"
        else:
            instance_prompt = self.instance_prompt
        text_inputs = tokenize_prompt(
            self.tokenizer, instance_prompt, tokenizer_max_length=self.tokenizer_max_length
        )
        example["instance_prompt_ids"] = text_inputs.input_ids
        example["instance_attention_mask"] = text_inputs.attention_mask

    if self.class_data_root:
        # Prior-preservation branch: pair each step with a class image.
        class_image = Image.open(self.class_images_path[index % self.num_class_images])
        class_image = exif_transpose(class_image)
        if class_image.mode != "RGB":
            class_image = class_image.convert("RGB")
        example["class_images"] = self.image_transforms(class_image)

        if self.class_prompt_encoder_hidden_states is not None:
            example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
        else:
            class_text_inputs = tokenize_prompt(
                self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
            )
            example["class_prompt_ids"] = class_text_inputs.input_ids
            example["class_attention_mask"] = class_text_inputs.attention_mask

    return example