in mlebench/competitions/AI4Code/prepare.py [0:0]
def _copy_notebooks(notebook_ids, src_dir: Path, dst_dir: Path, desc: str) -> None:
    """Copy ``{id}.json`` for every id from ``src_dir`` into ``dst_dir`` (created if missing)."""
    dst_dir.mkdir(parents=True, exist_ok=True)
    for notebook_id in tqdm(notebook_ids, desc=desc):
        shutil.copy(src_dir / f"{notebook_id}.json", dst_dir / f"{notebook_id}.json")


def _build_sample_submission(notebook_ids, notebook_dir: Path) -> pd.DataFrame:
    """Build a baseline submission that keeps each notebook's cells in their stored JSON order.

    Args:
        notebook_ids: ids of the notebooks to include, in output row order.
        notebook_dir: directory holding one ``{id}.json`` file per notebook.

    Returns:
        A DataFrame with columns ``id`` and ``cell_order`` (space-separated cell ids),
        one row per id in ``notebook_ids``.
    """
    rows = []
    for notebook_id in tqdm(notebook_ids, desc="Creating sample submission"):
        # JSON is UTF-8 by specification; don't depend on the platform's locale
        # encoding (not UTF-8 on e.g. Windows), which breaks on non-ASCII notebook text.
        with open(notebook_dir / f"{notebook_id}.json", encoding="utf-8") as f:
            json_data = json.load(f)
        # The keys of the "cell_type" mapping are the cell ids; json.load preserves
        # their order as stored in the file, which is the "unchanged" order we submit.
        cell_order = list(json_data["cell_type"])
        rows.append({"id": notebook_id, "cell_order": " ".join(cell_order)})
    return pd.DataFrame(rows)


def prepare(raw: Path, public: Path, private: Path):
    """Prepare the AI4Code competition data by re-splitting the original training notebooks.

    Both the new train and the new test split are carved out of the raw *train*
    notebooks. The original test set is not used.

    Files written:
        public/train/{id}.json, public/test/{id}.json  -- notebook JSON copies
        public/train_orders.csv                         -- cell orders for the new train split
        public/train_ancestors.csv                      -- ancestry rows, test notebooks removed
        public/sample_submission.csv                    -- cells kept in their stored JSON order
        private/test_orders.csv                         -- ground-truth orders for the new test split

    Args:
        raw: directory containing ``train/``, ``train_orders.csv`` and ``train_ancestors.csv``.
        public: output directory for participant-visible files.
        private: output directory for grader-only answers.

    Raises:
        AssertionError: if the sample submission does not contain exactly one row per
            test notebook (e.g. a test id has no entry in ``train_orders.csv``).
    """
    train_ancestors_df = read_csv(raw / "train_ancestors.csv")
    # Shuffle so the split is random; the fixed seed keeps the split reproducible.
    train_ancestors_df = train_ancestors_df.sample(frac=1, random_state=0).reset_index(drop=True)
    # NOTE(review): test_size is presumably a notebook count -- confirm against
    # create_train_test_split, which is defined outside this view.
    new_train_ids, new_test_ids = create_train_test_split(train_ancestors_df, test_size=20000)

    # Copy json files to public; both splits come from the raw *train* folder.
    _copy_notebooks(new_train_ids, raw / "train", public / "train", "Copying train json files")
    _copy_notebooks(new_test_ids, raw / "train", public / "test", "Copying test json files")

    # Answers: the original train_orders.csv is partitioned between the two splits.
    train_orders = read_csv(raw / "train_orders.csv")
    train_orders_new = train_orders[train_orders["id"].isin(new_train_ids)]
    train_orders_new.to_csv(public / "train_orders.csv", index=False)
    test_orders_new = train_orders[train_orders["id"].isin(new_test_ids)]
    test_orders_new.to_csv(private / "test_orders.csv", index=False)

    # Public ancestry metadata must not contain rows for the held-out test notebooks.
    train_ancestors_df = train_ancestors_df[~train_ancestors_df["id"].isin(new_test_ids)]
    train_ancestors_df.to_csv(public / "train_ancestors.csv", index=False)

    sample_submission = _build_sample_submission(test_orders_new["id"], public / "test")
    # Validate before writing so a malformed sample submission is never left on disk.
    assert len(sample_submission) == len(
        new_test_ids
    ), "Sample submission length does not match test length."
    sample_submission.to_csv(public / "sample_submission.csv", index=False)