def get_dataset()

in data.py [0:0]


def get_dataset(accelerator=None, dataset_id=None, cache_dir=None, num_proc=4):
    """Load a dataset's ``train`` split, keep selected columns, shuffle it,
    and attach ``preprocess_batch`` as an on-access transform.

    Args:
        accelerator: An ``accelerate`` state object used to coordinate
            processes. Defaults to a new ``accelerate.PartialState()``.
        dataset_id: Dataset identifier/path forwarded to ``load_dataset``.
            Required despite the ``None`` default.
        cache_dir: Optional cache directory forwarded to ``load_dataset``.
        num_proc: Currently unused by this function; accepted for API
            compatibility.

    Returns:
        The ``train`` split restricted to the columns in ``keep_cols``,
        shuffled with a fixed seed (2025), with ``preprocess_batch`` set via
        ``Dataset.with_transform`` (applied lazily when rows are accessed).

    Raises:
        ValueError: If ``dataset_id`` is not provided.
    """
    if dataset_id is None:
        raise ValueError("dataset_id must be provided")
    if accelerator is None:
        accelerator = accelerate.PartialState()

    keep_cols = {"Color", "image", "Lighting", "Lighting Type", "Composition"}

    # Download/prepare and cache-writing steps run on the main process first,
    # so the other ranks reuse its cache instead of racing on the same files.
    with accelerator.main_process_first():
        dataset = load_dataset(dataset_id, split="train", cache_dir=cache_dir)
        # Build the drop list in the dataset's own column order. Iterating a
        # set of strings depends on the hash seed, which can differ between
        # processes and would give each rank a different cache fingerprint.
        drop_cols = [col for col in dataset.column_names if col not in keep_cols]
        dataset = dataset.remove_columns(drop_cols)
        dataset = dataset.shuffle(seed=2025)
        dataset = dataset.with_transform(preprocess_batch)
    return dataset