in scripts/compute_corrupt.py [0:0]
import random

import pandas as pd
from transformers import BertTokenizer


def prepare_corruptions(args, pid=0, scheme="rand_utt"):
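    """Build one corrupted-response CSV for the given corruption scheme.

    args.corrupt_pre prefixes every input/output CSV path, args.drop_per is
    the per-word drop probability used by the word-drop schemes, and pid
    tags log lines and output filenames.
    """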
true_df = pd.read_csv(args.corrupt_pre + "true_response.csv")
seq_df = pd.read_csv(args.corrupt_pre + "seq2seq.csv")
back_trans = pd.read_csv(args.corrupt_pre + "backtranslate.csv")
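    # true_response.csv holds the gold responses, seq2seq.csv the model
    # generations, and backtranslate.csv the backtranslated paraphrases.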
print("[{}] loaded data".format(pid))
## Semantic Corruption
if scheme == "rand_utt":
        # NCE Random Utterance: pair each context with the gold response of
        # a randomly chosen *other* dialog.
        rand_utt_rows = []
        dialog_ids = true_df["dialog_id"].unique().tolist()
        for i, row in true_df.iterrows():
            other_dial = random.choice(
                [d for d in dialog_ids if d != row["dialog_id"]]
            )
sampled_resp = (
true_df[true_df["dialog_id"] == other_dial]
.sample(n=1)["true_response"]
.values[0]
)
row = {
"dialog_id": row["dialog_id"],
"context_id": row["context_id"],
"context": row["context"],
"context_hash": row["context_hash"] if "context_hash" in row else 1,
"rand_utt": sampled_resp,
}
rand_utt_rows.append(row)
rand_utt_df = pd.DataFrame(rand_utt_rows)
rand_utt_df.to_csv(args.corrupt_pre + "rand_utt_{}.csv".format(pid))
if scheme == "model_false":
        # NCE Random Model Response: pair each context with a seq2seq
        # response generated for a different dialog; ids are drawn from
        # seq_df itself so every sampled id is guaranteed to exist there.
        model_false_rows = []
        dialog_ids = seq_df["dialog_id"].unique().tolist()
        for i, row in seq_df.iterrows():
            other_dial = random.choice(
                [d for d in dialog_ids if d != row["dialog_id"]]
            )
sampled_resp = (
seq_df[seq_df["dialog_id"] == other_dial]
.sample(n=1)["seq2seq"]
.values[0]
)
row = {
"dialog_id": row["dialog_id"],
"context_id": row["context_id"],
"context": row["context"],
"context_hash": row["context_hash"] if "context_hash" in row else 1,
"model_false": sampled_resp,
}
model_false_rows.append(row)
model_false_df = pd.DataFrame(model_false_rows)
model_false_df.to_csv(args.corrupt_pre + "model_false_{}.csv".format(pid))
if scheme == "rand_back":
        # NCE Random Backtranslated Response: pair each context with a
        # backtranslation produced for a different dialog; ids are drawn
        # from back_trans itself so every sampled id exists there.
        rand_back_rows = []
        dialog_ids = back_trans["dialog_id"].unique().tolist()
        for i, row in back_trans.iterrows():
            other_dial = random.choice(
                [d for d in dialog_ids if d != row["dialog_id"]]
            )
sampled_resp = (
back_trans[back_trans["dialog_id"] == other_dial]
.sample(n=1)["backtranslate"]
.values[0]
)
row = {
"dialog_id": row["dialog_id"],
"context_id": row["context_id"],
"context": row["context"],
"context_hash": row["context_hash"] if "context_hash" in row else 1,
"rand_back": sampled_resp,
}
rand_back_rows.append(row)
rand_back_df = pd.DataFrame(rand_back_rows)
rand_back_df.to_csv(args.corrupt_pre + "rand_back_{}.csv".format(pid))
## Syntactic Corruption
if scheme == "word_drop":
        # NCE Word Drop: independently drop each word of the true response
        # with probability args.drop_per, keeping the special tokens.
        nce_drop_rows = []
        for i, row in true_df.iterrows():
            response = row["true_response"]
            words = response.split(" ")
            drop_word_pos = []
            for wi, word in enumerate(words):
                flip = random.uniform(0, 1)
                if flip <= args.drop_per and word not in ["[CLS]", "[SEP]"]:
                    drop_word_pos.append(wi)
            response = [w for wi, w in enumerate(words) if wi not in drop_word_pos]
            # if every word was dropped, fall back to a shuffled copy
            if len(response) == 0:
                response = random.sample(words, len(words))
            response = " ".join(response)
            # guard against an empty or whitespace-only corruption
            if len(response.strip()) == 0:
                response = "word"
row = {
"dialog_id": row["dialog_id"],
"context_id": row["context_id"],
"context": row["context"],
"context_hash": row["context_hash"] if "context_hash" in row else 1,
"word_drop": response,
}
nce_drop_rows.append(row)
nce_drop_df = pd.DataFrame(nce_drop_rows)
nce_drop_df.to_csv(args.corrupt_pre + "word_drop_{}.csv".format(pid))
if scheme == "rand_word_drop":
        # NCE Random Word Drop: sample a gold response from a different
        # dialog, then drop each of its words with probability args.drop_per.
        nce_rand_drop_rows = []
        dialog_ids = true_df["dialog_id"].unique().tolist()
        for i, row in true_df.iterrows():
            other_dial = random.choice(
                [d for d in dialog_ids if d != row["dialog_id"]]
            )
            response = (
                true_df[true_df["dialog_id"] == other_dial]
                .sample(n=1)["true_response"]
                .values[0]
            )
            words = response.split(" ")
            drop_word_pos = []
            for wi, word in enumerate(words):
                flip = random.uniform(0, 1)
                if flip <= args.drop_per and word not in ["[CLS]", "[SEP]"]:
                    drop_word_pos.append(wi)
            response = [w for wi, w in enumerate(words) if wi not in drop_word_pos]
            # if every word was dropped, fall back to a shuffled copy
            if len(response) == 0:
                response = random.sample(words, len(words))
            response = " ".join(response)
            # guard against an empty or whitespace-only corruption
            if len(response.strip()) == 0:
                response = "word"
row = {
"dialog_id": row["dialog_id"],
"context_id": row["context_id"],
"context": row["context"],
"context_hash": row["context_hash"] if "context_hash" in row else 1,
"rand_word_drop": response,
}
nce_rand_drop_rows.append(row)
nce_rand_drop_df = pd.DataFrame(nce_rand_drop_rows)
nce_rand_drop_df.to_csv(args.corrupt_pre + "rand_word_drop_{}.csv".format(pid))
if scheme == "word_order":
        # NCE Word Order: randomly shuffle the words of the true response.
        nce_order_rows = []
for i, row in true_df.iterrows():
response = row["true_response"]
words = response.split(" ")
            response = random.sample(words, len(words))  # a shuffled copy
response = " ".join(response)
row = {
"dialog_id": row["dialog_id"],
"context_id": row["context_id"],
"context": row["context"],
"context_hash": row["context_hash"] if "context_hash" in row else 1,
"word_order": response,
}
nce_order_rows.append(row)
nce_order_df = pd.DataFrame(nce_order_rows)
nce_order_df.to_csv(args.corrupt_pre + "word_order_{}.csv".format(pid))
if scheme == "word_repeat":
        # NCE Word Repeat: pick one of the first two words and repeat it to
        # the end of the sentence, mimicking the "i have have have ..."
        # degeneration common in seq2seq models.
        nce_repeat_rows = []
for i, row in true_df.iterrows():
response = row["true_response"]
words = response.split(" ")
            # start repeating from the first or second word
            repeat_word_indx = random.randrange(min(2, len(words)))
            repeat_word = words[repeat_word_indx]
            response = words[:repeat_word_indx]
            response = response + [repeat_word] * (len(words) - len(response))
response = " ".join(response)
row = {
"dialog_id": row["dialog_id"],
"context_id": row["context_id"],
"context": row["context"],
"context_hash": row["context_hash"] if "context_hash" in row else 1,
"word_repeat": response,
}
nce_repeat_rows.append(row)
nce_repeat_df = pd.DataFrame(nce_repeat_rows)
nce_repeat_df.to_csv(args.corrupt_pre + "word_repeat_{}.csv".format(pid))
## Corrupting the context
if scheme == "corrupt_context":
        # Corrupt Context: keep the true response but replace the whole
        # context with utterances sampled from other dialogs.
        # NOTE: load_path (the pretrained BERT path) is assumed to be
        # defined at module level.
        tokenizer = BertTokenizer.from_pretrained(load_path)
        def get_sent_len(context):
            # wordpiece length of the joined context, including the special
            # tokens added by build_inputs_with_special_tokens
            X = [tokenizer.tokenize(sent) for sent in context]
            X = [tokenizer.convert_tokens_to_ids(sent) for sent in X]
            X = [word for sent in X for word in sent]
            X = tokenizer.build_inputs_with_special_tokens(X)
            return len(X)
        nce_cont_cor_rows = []
        # pool of individual utterances drawn from every context
        context_sents = [row["context"].split("\n") for i, row in true_df.iterrows()]
        context_sents = [y for x in context_sents for y in x]
for i, row in true_df.iterrows():
context = row["context"]
            context_len = len(row["context"].split("\n"))
            # resample until the fake context fits in BERT's 512-token limit
            while True:
                random_context = random.sample(context_sents, context_len)
                if get_sent_len(random_context) < 512:
                    break
response = row["true_response"]
row = {
"dialog_id": row["dialog_id"],
"context_id": row["context_id"],
"context": "\n".join(random_context),
"context_hash": row["context_hash"] if "context_hash" in row else 1,
"corrupt_context": response,
}
nce_cont_cor_rows.append(row)
nce_cont_cor_df = pd.DataFrame(nce_cont_cor_rows)
nce_cont_cor_df.to_csv(args.corrupt_pre + "corrupt_context_{}.csv".format(pid))
return "[{}] {} done".format(pid, scheme)