data.py (84 lines of code) (raw):
from datasets import load_dataset
import accelerate
# Maps dataset column names to their "missing value" placeholder.
# preprocess_batch derives both the output column name (<COLOR>, <LIGHTING>,
# <LIGHTING_TYPE>, <COMPOSITION>) and the human-readable "unspecified ..."
# fallback text from these entries.
# NOTE(review): keys must match the dataset's real column names (see
# keep_cols in get_dataset: "Color", "Lighting", "Lighting Type",
# "Composition"). The previous ALL_CAPS keys never matched, so every example
# silently fell back to the "unspecified ..." default; also fixed the
# "no_ligting_type" typo that leaked into the fallback text.
NONE_KEY_MAP = {
    "Color": ["no_color"],
    "Lighting": ["no_lighting"],
    "Lighting Type": ["no_lighting_type"],
    "Composition": ["no_composition"],
}
def preprocess_batch(rows, key_map=None):
    """
    Transform one batch of examples (dict of column name -> list of values).

    For every key in ``key_map`` (default: module-level ``NONE_KEY_MAP``) a
    new column named ``<KEY_WITH_SPACES_AS_UNDERSCORES_UPPERCASED>`` is
    appended: each entry is the example's value (lists joined with ", ",
    scalars stringified) or, when the value is falsy/missing, an
    "unspecified ..." placeholder derived from the map's default. Images in
    the "image" column are converted to RGB in place when needed.

    Args:
        rows: batch mapping; all value lists must share the same length.
        key_map: mapping of column name -> [default token]; defaults to
            ``NONE_KEY_MAP`` when None (backward compatible).

    Returns:
        The same ``rows`` dict, mutated in place, with the processed
        columns merged in.
    """
    if key_map is None:
        key_map = NONE_KEY_MAP
    n = len(next(iter(rows.values())))
    # Hoist per-key invariants out of the per-example loop: the derived
    # column name, the source column (a single shared [None]*n default
    # instead of one allocation per index — the original was O(keys * n^2)
    # in allocations), and the "unspecified ..." fallback string.
    processed_data = {}
    per_key = []
    for k, defaults in key_map.items():
        new_key = f"<{k.replace(' ', '_').upper()}>"
        processed_data[new_key] = []
        fallback = defaults[0].replace("no_", "unspecified ").replace("_", " ")
        per_key.append((new_key, rows.get(k, [None] * n), fallback))
    for i in range(n):
        for new_key, column, fallback in per_key:
            value = column[i]
            if value:
                detail = ", ".join(value) if isinstance(value, list) else str(value)
            else:
                detail = fallback
            processed_data[new_key].append(detail)
        # Convert PIL-style images to RGB; leave None / mode-less entries as-is.
        if "image" in rows:
            image = rows["image"][i]
            if image is not None and hasattr(image, "mode") and image.mode != "RGB":
                rows["image"][i] = image.convert("RGB")
    # Merge the processed columns into the original batch dictionary.
    rows.update(processed_data)
    return rows
def get_dataset(accelerator=None, dataset_id=None, cache_dir=None, num_proc=4):
    """Load the train split, keep only the caption columns, shuffle, and
    attach the lazy batch transform.

    Args:
        accelerator: optional accelerate state object; a fresh
            ``PartialState`` is created when None.
        dataset_id: hub id passed to ``load_dataset``.
        cache_dir: optional datasets cache directory.
        num_proc: accepted for interface compatibility (currently unused).

    Returns:
        The shuffled dataset with ``preprocess_batch`` set via
        ``with_transform``.
    """
    state = accelerate.PartialState() if accelerator is None else accelerator
    wanted = {"Color", "image", "Lighting", "Lighting Type", "Composition"}
    ds = load_dataset(dataset_id, split="train", cache_dir=cache_dir)
    # Drop every column we do not train on.
    surplus = [col for col in ds.features if col not in wanted]
    ds = ds.remove_columns(surplus)
    # Shuffle on the main process first so workers reuse the cached order.
    with state.main_process_first():
        ds = ds.shuffle(seed=2025).with_transform(preprocess_batch)
    return ds
def collate_fn(batch, processor, max_length=800):
    """
    Collate a list of samples into per-field model inputs.

    For each of the four caption fields this builds two entries:
      * ``<field>_inputs`` — the processor output for that field's
        placeholder prompt (one ``<FIELD>`` token per sample) paired with
        the batch images.
      * the pluralized field name (``colors``, ``lightings``,
        ``lighting_types``, ``compositions``) — the tokenized target texts
        (``input_ids``) for that field.

    Args:
        batch: list of sample dicts, each carrying "image" and the
            "<COLOR>"/"<LIGHTING>"/"<LIGHTING_TYPE>"/"<COMPOSITION>"
            columns produced by ``preprocess_batch``.
        processor: multimodal processor exposing a ``tokenizer`` attribute.
        max_length: truncation length for both prompts and targets.

    Returns:
        dict with eight entries: four ``*_inputs`` and four target-id
        collections.
    """
    images = [sample["image"] for sample in batch]
    # Map each field to the processed column that holds its text.
    field_map = {
        "color": "<COLOR>",
        "lighting": "<LIGHTING>",
        "lighting_type": "<LIGHTING_TYPE>",
        "composition": "<COMPOSITION>",
    }
    collated = {}
    for name, key in field_map.items():
        # One placeholder prompt per sample; the real text is the target.
        prompts = [key] * len(batch)
        texts = [sample[key] for sample in batch]
        # Tokenize the raw target texts.
        tokenized = processor.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
            return_token_type_ids=False,
        ).input_ids
        # Process the images along with the placeholder prompts.
        processed_inputs = processor(
            text=prompts,
            images=images,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        collated[f"{name}_inputs"] = processed_inputs
        # The target key is just the pluralized field name ("color" ->
        # "colors", "lighting_type" -> "lighting_types", ...), so the old
        # per-field if/elif chain was redundant.
        collated[f"{name}s"] = tokenized
    return collated