def collate_fn()

in data.py [0:0]


def collate_fn(batch, processor, max_length=800):
    """Collate a batch of samples into per-attribute model inputs and targets.

    For each attribute (color, lighting, lighting type, composition):

    - The batch's images are passed to ``processor``. The attribute's task
      token, repeated once per sample, is used as the text prompt.
    - The attribute's target text is tokenized separately with
      ``processor.tokenizer``.

    Args:
        batch: Sequence of sample dicts. Each sample must provide an
            ``"image"`` entry and one text entry per task token:
            ``"<COLOR>"``, ``"<LIGHTING>"``, ``"<LIGHTING_TYPE>"`` and
            ``"<COMPOSITION>"``.
        processor: Object callable as ``processor(text=..., images=..., ...)``
            that also exposes a ``tokenizer`` attribute. Presumably a Hugging
            Face multimodal processor; TODO confirm against the caller.
        max_length: Truncation length passed to both the tokenizer call and
            the processor call.

    Returns:
        A dict with two entries per attribute:

        - ``"<name>_inputs"``: the processor output for the prompts + images.
        - A plural key (``"colors"``, ``"lightings"``, ``"lighting_types"``,
          ``"compositions"``): the ``input_ids`` of the tokenized target texts.

        Entries are inserted in attribute order, inputs before targets.
    """
    # Each entry is (attribute name, task token / sample key, output key for
    # the tokenized targets).
    tasks = (
        ("color", "<COLOR>", "colors"),
        ("lighting", "<LIGHTING>", "lightings"),
        ("lighting_type", "<LIGHTING_TYPE>", "lighting_types"),
        ("composition", "<COMPOSITION>", "compositions"),
    )

    imgs = [item["image"] for item in batch]
    out = {}

    for attr, token, target_key in tasks:
        # The task token doubles as the prompt text and as the key under which
        # each sample stores its target text.
        placeholder_prompts = [token] * len(batch)
        target_texts = [item[token] for item in batch]

        target_encoding = processor.tokenizer(
            target_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
            return_token_type_ids=False,
        )
        target_ids = target_encoding.input_ids

        # NOTE(review): the same images are re-processed for every attribute.
        model_inputs = processor(
            text=placeholder_prompts,
            images=imgs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )

        out[attr + "_inputs"] = model_inputs
        out[target_key] = target_ids

    return out