in src/prepare_data.py [0:0]
def create_data(args):
    # Create the output directory if it does not exist yet.
    if not os.path.exists(args.output_path):
        print(f"Creating directory: {args.output_path}")
        os.mkdir(args.output_path)

    for split in SPLITs:
        print(f"Processing Split: {split}")
        srcs = read_src(os.path.join(args.data_path, f"{split}.source"))
        tgts = read_tgt(os.path.join(args.data_path, f"{split}.target"))
        if args.graph_encoding:
            graph_info = read_graph(os.path.join(args.graph_data_path, f"{split}.jsonl"))

        if args.shuffle_sentences:
            # Ablation: shuffle the sentences within each document.
            print("Shuffling the sentences within a document")
            new_srcs = []
            for ind, src in enumerate(srcs):
                new_docs = []
                for doc in src:
                    doc_sents = sent_tokenize(doc)
                    random.shuffle(doc_sents)
                    new_docs.append(" ".join(doc_sents))
                new_srcs.append(new_docs)
                if ind == 0:
                    print("Sample original src:::", srcs[0])
                    print("Sample sentence-shuffled src:::", new_srcs[0])
            srcs = new_srcs

        # Truncate the concatenated source to at most `args.max_length` words,
        # then split it back into its per-document pieces.
        new_srcs = []
        for src in srcs:
            src = " story_separator_special_tag ".join(src)
            src = truncate(src, total_words=args.max_length)
            src = src.split(" story_separator_special_tag ")
            new_srcs.append(src)

        f_src = os.path.join(args.output_path, f"{split}.source")
        f_tgt = os.path.join(args.output_path, f"{split}.target")

        if args.sentence_level_markers:
            if args.graph_encoding:
                # Bucket each sentence's graph score into low/medium/high using the
                # per-example 33rd and 67th percentiles, then append the bucket label
                # and a separator tag after every sentence.
                new_srcs_g = []
                for index, src in enumerate(new_srcs):
                    scores_list = [v["score"] for k, v in graph_info[index].items()]
                    threshold1 = np.quantile(np.array(scores_list), 0.33)
                    threshold2 = np.quantile(np.array(scores_list), 0.67)
                    new_docs = []
                    for i, doc in enumerate(src):
                        new_doc = []
                        for j, sent in enumerate(sent_tokenize(doc)):
                            if j < GRAPH_SENT_LIMIT:
                                id_ = "d{}_s{}".format(i, j)
                                score = graph_info[index][id_]["score"]
                                if score > threshold2:
                                    label = "high"
                                elif score > threshold1:
                                    label = "medium"
                                else:
                                    label = "low"
                                new_doc.append(sent + f" graph score is {label} {SEPARATER_TAG}")
                            else:
                                # Sentences beyond the graph limit only get the separator tag.
                                new_doc.append(sent + f" {SEPARATER_TAG}")
                        new_docs.append(" ".join(new_doc))
                    new_srcs_g.append(" ".join(new_docs))
                srcs = new_srcs_g
            else:
                # TODO: joining the documents with a plain space is fragile here,
                # since sent_tokenize can mis-split sentences at document boundaries.
                srcs = [" ".join(src) for src in new_srcs]
                srcs = [sent_tokenize(src) for src in srcs]
                srcs = [f" {SEPARATER_TAG} ".join(src) for src in srcs]
        else:
            srcs = [f" {SEPARATER_TAG} ".join(src) for src in new_srcs]

        # Write one example per line; the `with` blocks close the files automatically.
        with open(f_src, "w") as f:
            f.write("\n".join(srcs))
        with open(f_tgt, "w") as f:
            f.write("\n".join(tgts))
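
For reference, create_data expects an args namespace exposing output_path, data_path, graph_data_path, max_length, graph_encoding, shuffle_sentences, and sentence_level_markers. The sketch below shows one way such a namespace could be built with argparse; the flag names are taken from the attribute accesses above, but the defaults and help strings are assumptions, not the repository's actual CLI.

# Minimal sketch (assumed, not the project's actual entry point): a CLI that
# could drive create_data(). Flag names mirror the args.* attributes used above.
import argparse

def build_parser():
    parser = argparse.ArgumentParser(description="Prepare multi-document summarization splits.")
    parser.add_argument("--data_path", required=True,
                        help="Directory containing <split>.source / <split>.target files.")
    parser.add_argument("--output_path", required=True,
                        help="Directory the processed splits are written to.")
    parser.add_argument("--graph_data_path", default=None,
                        help="Directory containing <split>.jsonl graph-score files.")
    parser.add_argument("--max_length", type=int, default=1024,
                        help="Word budget passed to truncate() (default is an assumption).")
    parser.add_argument("--graph_encoding", action="store_true",
                        help="Append low/medium/high graph-score labels to sentences.")
    parser.add_argument("--shuffle_sentences", action="store_true",
                        help="Shuffle sentences within each document (ablation).")
    parser.add_argument("--sentence_level_markers", action="store_true",
                        help="Insert a separator tag after every sentence.")
    return parser

if __name__ == "__main__":
    create_data(build_parser().parse_args())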