mlebench/competitions/text-normalization-challenge-english-language/prepare.py (89 lines of code) (raw):
import csv
import zipfile
from pathlib import Path
from sklearn.model_selection import train_test_split
from mlebench.utils import compress, extract, read_csv
def _renumber_sentence_ids(frame):
    """Return a new frame whose "sentence_id" values are renumbered from 0.

    New ids are handed out in order of first appearance of each old id, so
    rows that shared a sentence_id before still share one afterwards.
    ``DataFrame.assign`` is used (instead of ``frame.loc[:, ...] = ...``) so
    that a frame obtained by filtering another frame is never mutated in
    place, which would trigger pandas' SettingWithCopyWarning.
    """
    id_mapping = {
        old_id: new_id for new_id, old_id in enumerate(frame["sentence_id"].unique())
    }
    return frame.assign(sentence_id=frame["sentence_id"].map(id_mapping))


def _add_row_id(frame):
    """Return a new frame with an extra "id" column "<sentence_id>_<token_id>"."""
    row_id = frame["sentence_id"].astype(str) + "_" + frame["token_id"].astype(str)
    return frame.assign(id=row_id)


def prepare(raw: Path, public: Path, private: Path) -> None:
    """Build the public/private competition files from the raw training data.

    The original training set is split by sentence (10% of the sentence ids
    are held out with a fixed random seed) into a new train set and a new
    test set with hidden answers.

    Args:
        raw: Directory holding the downloaded ``*.csv.zip`` archives; they are
            extracted into this same directory.
        public: Output directory for the zipped train set, test set and
            sample submission.
        private: Output directory for ``answers.csv`` and
            ``sample_submission.csv``.

    Raises:
        AssertionError: If any of the sanity checks on the split fails.
    """
    # Extract
    extract(raw / "en_test_2.csv.zip", raw)  # We only use the 2nd stage test set
    extract(raw / "en_train.csv.zip", raw)
    extract(raw / "en_sample_submission_2.csv.zip", raw)

    # Create train and test splits from train set
    old_train = read_csv(raw / "en_train.csv")

    # We split so that we don't share any sentence_ids between train and test
    # This gives us len(new_train) = 8924976 and len(answers) = 993465
    # Original was len(old_train) = 9918441 and len(old_test) = 956046
    unique_sentence_ids = old_train["sentence_id"].unique()
    train_sentence_ids, test_sentence_ids = train_test_split(
        unique_sentence_ids, test_size=0.1, random_state=0
    )
    new_train = old_train[old_train["sentence_id"].isin(train_sentence_ids)]
    answers = old_train[old_train["sentence_id"].isin(test_sentence_ids)]
    assert set(new_train["sentence_id"]).isdisjoint(
        set(answers["sentence_id"])
    ), "sentence_id is not disjoint between train and test sets"

    # "sentence_id" counts need to be reset for new_train and answers.
    # Each call returns a fresh frame, so the filtered selections above are
    # never written to in place.
    new_train = _renumber_sentence_ids(new_train)
    answers = _renumber_sentence_ids(answers)

    # Create new test set (same rows as the answers, minus the label columns)
    new_test = answers.drop(["after", "class"], axis=1).copy()

    # Reformat answers to match sample submission format
    answers = _add_row_id(answers)[["id", "after"]]

    # Create sample submission: predicts every token unchanged ("after" == "before")
    sample_submission = _add_row_id(new_test)
    sample_submission = sample_submission.assign(after=sample_submission["before"])
    sample_submission = sample_submission[["id", "after"]]

    # Checks
    assert new_train.columns.tolist() == [
        "sentence_id",
        "token_id",
        "class",
        "before",
        "after",
    ], f"new_train.columns.tolist() == {new_train.columns.tolist()}"
    assert new_test.columns.tolist() == [
        "sentence_id",
        "token_id",
        "before",
    ], f"new_test.columns.tolist() == {new_test.columns.tolist()}"
    assert sample_submission.columns.tolist() == [
        "id",
        "after",
    ], f"sample_submission.columns.tolist() == {sample_submission.columns.tolist()}"
    assert answers.columns.tolist() == [
        "id",
        "after",
    ], f"answers.columns.tolist() == {answers.columns.tolist()}"
    assert len(new_test) + len(new_train) == len(
        old_train
    ), f"New train and test sets do not sum to old train set, got {len(new_test) + len(new_train)} and {len(old_train)}"

    # Write CSVs
    csv_options = {"index": False, "quotechar": '"', "quoting": csv.QUOTE_NONNUMERIC}
    answers.to_csv(private / "answers.csv", **csv_options)
    sample_submission.to_csv(private / "sample_submission.csv", **csv_options)
    new_train.to_csv(public / "en_train.csv", **csv_options)
    new_test.to_csv(public / "en_test_2.csv", **csv_options)
    sample_submission.to_csv(public / "en_sample_submission_2.csv", **csv_options)

    # Zip up each public CSV, then remove the loose CSV so only the archive remains
    for csv_name in ("en_train.csv", "en_test_2.csv", "en_sample_submission_2.csv"):
        csv_path = public / csv_name
        with zipfile.ZipFile(public / f"{csv_name}.zip", "w") as zipf:
            zipf.write(csv_path, arcname=csv_name)
        csv_path.unlink()