in data.py [0:0]
def collate_fn(batch, processor, max_length=800):
    """Collate a batch of samples into per-field model inputs and tokenized targets.

    Args:
        batch: List of sample dicts. Each sample holds an ``"image"`` plus the
            raw target text stored under each task key (e.g. ``sample["<COLOR>"]``).
        processor: Multimodal processor exposing a ``tokenizer`` attribute and
            callable on ``text=``/``images=`` — presumably a HuggingFace
            processor (e.g. Florence-2); TODO confirm against the caller.
        max_length: Token-length cap applied to both the target tokenization
            and the prompt/image processing. Defaults to 800.

    Returns:
        Dict containing, for every field ``name`` in ``field_map``:
        ``f"{name}_inputs"`` — the processor output for the placeholder
        prompts plus images — and the pluralized ``f"{name}s"`` — the
        ``input_ids`` tensor of the tokenized target texts.
    """
    images = [sample["image"] for sample in batch]
    # Map each logical field name to the task-prompt key used in the samples.
    field_map = {
        "color": "<COLOR>",
        "lighting": "<LIGHTING>",
        "lighting_type": "<LIGHTING_TYPE>",
        "composition": "<COMPOSITION>",
    }
    collated = {}
    for name, key in field_map.items():
        # One placeholder prompt per sample; the actual target text lives
        # under the same task key inside each sample.
        prompts = [key] * len(batch)
        texts = [sample[key] for sample in batch]
        # Tokenize the raw target texts.
        tokenized = processor.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
            return_token_type_ids=False,
        ).input_ids
        # Process the images together with the placeholder prompts.
        # NOTE: the same images are re-processed once per field; fine for a
        # four-entry field_map, but a hoisting opportunity if it grows.
        collated[f"{name}_inputs"] = processor(
            text=prompts,
            images=images,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        # Pluralize the field name for the tokenized-target key; this replaces
        # the previous per-field if/elif chain (all branches did exactly this)
        # and keeps new field_map entries from being silently dropped.
        collated[f"{name}s"] = tokenized
    return collated