in scripts/create_lama_uhn.py [0:0]
def _build_uhn_filters(args):
    """Instantiate the UHN filters selected by ``args.filters``.

    The string-match filter, when selected, always comes before the
    person-name filter, regardless of the order of names in ``args.filters``.

    Args:
        args: parsed command-line namespace; reads ``filters``,
            ``string_match_do_lowercase``, ``person_name_bert`` and
            ``person_name_top_k``.

    Returns:
        list of filter objects, each exposing a ``filter(queries)`` method.
    """
    uhn_filters = []
    if "string_match" in args.filters:
        uhn_filters.append(
            StringMatchFilter(do_lower_case=args.string_match_do_lowercase)
        )
    if "person_name" in args.filters:
        uhn_filters.append(
            PersonNameFilter(
                bert_name=args.person_name_bert, top_k=args.person_name_top_k
            )
        )
    return uhn_filters


def main(args):
    """Filter every JSON-lines file in ``args.srcdir`` into a sibling ``*_UHN`` directory.

    Each regular file in the source directory is read as one JSON object per
    non-blank line. The selected filters are applied in sequence, and the
    surviving queries are written, one JSON object per line, to a file with
    the same name in ``<srcdir>_UHN``. That directory is created if it is
    missing.

    Args:
        args: parsed command-line namespace; reads ``srcdir`` plus the
            filter options consumed by ``_build_uhn_filters``.

    Raises:
        NotADirectoryError: if ``args.srcdir`` is not an existing directory.
    """
    srcdir = args.srcdir
    # Use an explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.isdir(srcdir):
        raise NotADirectoryError(f"source directory not found: {srcdir!r}")
    # Drop trailing slashes so the suffix attaches to the directory name
    # itself ("data/" -> "data_UHN", not "data/_UHN").
    srcdir = srcdir.rstrip("/")
    tgtdir = srcdir + "_UHN"
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(tgtdir, exist_ok=True)

    uhn_filters = _build_uhn_filters(args)

    for filename in tqdm.tqdm(sorted(os.listdir(srcdir))):
        infile = os.path.join(srcdir, filename)
        if not os.path.isfile(infile):
            # Subdirectories and other non-regular entries cannot be read
            # as JSON-lines files.
            continue
        outfile = os.path.join(tgtdir, filename)
        with open(infile, encoding="utf-8") as handle:
            # Skip blank lines (e.g. a trailing empty line); json.loads
            # raises on them.
            queries = [json.loads(line) for line in handle if line.strip()]
        # Each filter receives the previous filter's output.
        for uhn_filter in uhn_filters:
            queries = uhn_filter.filter(queries)
        with open(outfile, "w", encoding="utf-8") as handle:
            handle.writelines(json.dumps(query) + "\n" for query in queries)