def main()

in src/ie/process_ner_data.py [0:0]


def main(input_file, train_file, test_file):
    """Deduplicate samples from a tab-separated file and split them into train/test files.

    Every non-blank line of ``input_file`` is stripped and split on tabs and
    must yield at least three columns:

    * column 0 is collected into a set whose size is reported as "trials";
    * column 1 is a comma-separated list of slots, each of which is split on
      ``":"`` and must have at least three fields (the third is counted as
      the slot's label);
    * column 2 is passed, together with the slot list, to ``transform``.

    A sample is identified by columns 1 and 2 together; repeated samples are
    ignored. Each new sample is written (as returned by ``transform``, plus a
    newline) to ``test_file`` when ``random.random() < test_ratio`` and to
    ``train_file`` otherwise. Summary counts and a per-label histogram are
    printed to stdout.

    NOTE(review): ``transform``, ``test_ratio`` and ``random`` are not defined
    in this function; they are expected to exist at module level.

    Args:
        input_file: Path of the tab-separated input file (read as UTF-8).
        train_file: Path of the training output file (overwritten, UTF-8).
        test_file: Path of the test output file (overwritten, UTF-8).

    Raises:
        IndexError: If a non-blank line has fewer than three tab-separated
            columns, or a slot has fewer than three ``":"``-separated fields.
    """
    samples = set()
    trials = set()
    histo = {}
    line_cnt = 0
    slot_cnt = 0
    train_cnt = 0
    test_cnt = 0

    # The input is opened first so that a missing input file does not leave
    # freshly truncated output files behind.
    with open(input_file, "r", encoding="utf-8") as reader, \
            open(train_file, "w", encoding="utf-8") as train_writer, \
            open(test_file, "w", encoding="utf-8") as test_writer:
        for line in reader:
            line_cnt += 1
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing newline) carry no sample; they
                # are counted as read but otherwise skipped instead of
                # failing on the column lookups below.
                continue

            values = stripped.split("\t")
            sample = values[1] + "\t" + values[2]
            trials.add(values[0])
            if sample in samples:
                continue
            samples.add(sample)

            slots = values[1].split(",")
            slot_cnt += len(slots)
            new_sample = transform(slots, values[2])
            for slot in slots:
                label = slot.split(":")[2]
                histo[label] = histo.get(label, 0) + 1

            if random.random() < test_ratio:
                test_writer.write(new_sample + "\n")
                test_cnt += 1
            else:
                train_writer.write(new_sample + "\n")
                train_cnt += 1

    # With no samples there is nothing to take a ratio of; report 0.0% rather
    # than dividing by zero.
    test_pct = 100 * test_cnt / len(samples) if samples else 0.0
    print(f"Train count: {train_cnt}, test count: {test_cnt} ({test_pct:.1f}%)")
    print(f"Lines read: {line_cnt}, slots: {slot_cnt}, samples: {len(samples)}, trials: {len(trials)}")
    print()

    # Most frequent labels first; the sort is stable, so ties keep the order
    # in which the labels were first seen.
    items = sorted(histo.items(), key=lambda item: item[1], reverse=True)
    print("label".ljust(22, " ") + "count")
    for label, count in items:
        print(f"{label:21s} {count:5d}")
    tot = sum(histo.values())
    print("total".ljust(21, " ") + f"{tot:6d}")