in data.py [0:0]
def get_dataset(accelerator=None, dataset_id=None, cache_dir=None, num_proc=4):
    """Load the ``train`` split of ``dataset_id``, keep only the needed columns, and shuffle it.

    Args:
        accelerator: An ``accelerate`` state object providing ``main_process_first()``.
            When ``None``, a fresh ``accelerate.PartialState()`` is created.
        dataset_id: Dataset identifier forwarded to ``load_dataset``. Required.
        cache_dir: Optional cache directory forwarded to ``load_dataset``.
        num_proc: Currently unused; kept so existing callers keep working.
            NOTE(review): presumably intended for ``load_dataset``/``map`` — confirm.

    Returns:
        The shuffled ``train`` split restricted to the columns in ``keep_cols``,
        with ``preprocess_batch`` registered through ``with_transform``.

    Raises:
        ValueError: If ``dataset_id`` is not provided.
    """
    if dataset_id is None:
        raise ValueError("dataset_id must be provided")
    if accelerator is None:
        accelerator = accelerate.PartialState()

    keep_cols = {"Color", "image", "Lighting", "Lighting Type", "Composition"}

    # Run every cache-producing step (download/prepare, column removal, shuffle
    # indices) on the main process first so other ranks reuse its cache instead
    # of racing to build the same files.
    with accelerator.main_process_first():
        dataset = load_dataset(dataset_id, split="train", cache_dir=cache_dir)
        # Sorted so every rank passes an identical list: set iteration order of
        # strings varies per process (hash randomization), which would otherwise
        # give each rank a different dataset fingerprint and defeat cache sharing.
        drop_cols = sorted(set(dataset.column_names) - keep_cols)
        dataset = dataset.remove_columns(drop_cols)
        dataset = dataset.shuffle(seed=2025)

    return dataset.with_transform(preprocess_batch)