def prepare()

in mlebench/competitions/AI4Code/prepare.py [0:0]


def prepare(raw: Path, public: Path, private: Path) -> None:
    """Split the raw training data into a new public train/test set plus private answers.

    Files read from ``raw``:
        * ``train_ancestors.csv`` (must have an ``id`` column)
        * ``train_orders.csv`` (must have an ``id`` column)
        * ``train/<id>.json`` for every id selected by the split

    Files written:
        * ``public/train/<id>.json`` and ``public/test/<id>.json``
        * ``public/train_orders.csv`` (orders of the new train ids)
        * ``private/test_orders.csv`` (orders of the new test ids)
        * ``public/train_ancestors.csv`` (ancestors with the new test ids removed)
        * ``public/sample_submission.csv`` (columns ``id`` and ``cell_order``)

    Args:
        raw: Directory holding the original competition files.
        public: Output directory for the files listed under ``public/`` above.
        private: Output directory for ``test_orders.csv``; assumed to already
            exist, since it is not created here.

    Raises:
        AssertionError: If the sample submission does not contain exactly one
            row per new test id.
    """
    train_ancestors_df = read_csv(raw / "train_ancestors.csv")
    # Shuffle with a fixed seed so the split is random but reproducible.
    train_ancestors_df = train_ancestors_df.sample(frac=1, random_state=0).reset_index(drop=True)
    # NOTE(review): create_train_test_split is defined elsewhere; it is used here
    # as returning two collections of ids -- confirm against its definition.
    new_train_ids, new_test_ids = create_train_test_split(train_ancestors_df, test_size=20000)

    # Copy json files to public. Both new splits are carved out of the raw
    # "train" directory, so the copy source is the same for train and test.
    raw_train_dir = raw / "train"
    for split_name, split_ids in (("train", new_train_ids), ("test", new_test_ids)):
        split_dir = public / split_name
        split_dir.mkdir(parents=True, exist_ok=True)
        for notebook_id in tqdm(split_ids, desc=f"Copying {split_name} json files"):
            shutil.copy(raw_train_dir / f"{notebook_id}.json", split_dir / f"{notebook_id}.json")

    # Generate answers for train and test
    train_orders = read_csv(raw / "train_orders.csv")
    # Answers for new train
    train_orders_new = train_orders[train_orders["id"].isin(new_train_ids)]
    train_orders_new.to_csv(public / "train_orders.csv", index=False)
    # Answers for new test (kept private)
    test_orders_new = train_orders[train_orders["id"].isin(new_test_ids)]
    test_orders_new.to_csv(private / "test_orders.csv", index=False)

    # Make new train_ancestors.csv, excluding the new_test_ids
    train_ancestors_df = train_ancestors_df[~train_ancestors_df["id"].isin(new_test_ids)]
    train_ancestors_df.to_csv(public / "train_ancestors.csv", index=False)

    # Create sample submission (use the given order without changing it)
    sample_submission_rows = []
    for sample_id in tqdm(test_orders_new["id"], desc="Creating sample submission"):
        # JSON is UTF-8 by specification; decode explicitly instead of relying
        # on the platform's locale encoding.
        with open(public / "test" / f"{sample_id}.json", encoding="utf-8") as f:
            json_data = json.load(f)
        # The key order of the "cell_type" mapping is the order the cells
        # appear in the json file.
        cell_order = list(json_data["cell_type"])
        sample_submission_rows.append({"id": sample_id, "cell_order": " ".join(cell_order)})
    # Explicit columns keep the CSV header intact even if there are no rows.
    sample_submission = pd.DataFrame(sample_submission_rows, columns=["id", "cell_order"])

    # Validate before writing so an inconsistent file is never left on disk.
    assert len(sample_submission) == len(
        new_test_ids
    ), "Sample submission length does not match test length."
    sample_submission.to_csv(public / "sample_submission.csv", index=False)