augmentation/augment.py (60 lines of code) (raw):

"""Augment a JSON dataset of text records with an augmentation pipeline.

Reads a JSON object mapping record ids to record objects, runs the text stored
under ``--text_key`` through a pipeline built from ``--config``, and writes a
JSON Lines file. The first line is the augmentation config, followed by each
original record and then its augmented variants.
"""
import argparse as ag
from pathlib import Path
import json
from copy import copy
from typing import List, Union, Dict

from nlpaug import Augmenter
from tqdm import tqdm

from parse_config import build_augmentation_pipeline, read_config
from utils import find_free_file


def parse_args():
    """Parse command-line arguments.

    Returns:
        ``argparse.Namespace`` with ``config``, ``input_json`` and
        ``output_json`` as ``Path`` objects and ``text_key`` as ``str``.
    """
    parser = ag.ArgumentParser(
        description='Augment the text field of a JSON dataset.'
    )
    parser.add_argument('-c', '--config', type=Path, required=True,
                        help='augmentation config file')
    parser.add_argument('-i', '--input-json', type=Path, required=True,
                        help='JSON object mapping record ids to records')
    parser.add_argument('-o', '--output-json', type=Path, required=True,
                        help='JSON Lines file to write')
    # The key indexes into JSON-decoded dicts, whose keys are always str.
    # Parsing it as a Path made every lookup raise KeyError.
    parser.add_argument('-t', '--text_key', type=str, required=True,
                        help='record field holding the text to augment')
    return parser.parse_args()


def read_original_data(
    path: Union[Path, str],
    original_key: str = 'original',
) -> List[Dict]:
    """Load the source dataset and mark every record as original.

    Args:
        path: JSON file whose top level is an object mapping record ids to
            record objects (each value must be a dict; ``.update`` is called
            on it).
        original_key: Field set to ``True`` on every loaded record.

    Returns:
        One dict per record, in file order. Each has ``'id'`` set to its key
        in the file, overwriting any existing ``'id'`` field, and
        ``original_key`` set to ``True``. The decoded dicts are updated in
        place and returned, not copied.
    """
    path = Path(path)
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with path.open('r', encoding='utf-8') as istream:
        data = json.load(istream)
    result = []
    for record_id, record in data.items():
        record.update({'id': record_id, original_key: True})
        result.append(record)
    return result


def augment_dataset(
        augmenter: Augmenter,
        data: List[Dict],
        text_key: str,
        out_file: Path,
        aug_config: Dict,
        original_key: str = 'original'
) -> None:
    """Write originals and their augmented variants as JSON Lines.

    Args:
        augmenter: Object whose ``augment(text)`` returns a single ``str``
            or a list of variants; both shapes are handled below.
        data: Records to augment; each must contain ``text_key``.
        text_key: Record field holding the text to augment.
        out_file: Requested output path. It is passed through
            ``find_free_file`` and opened in exclusive-create (``'x'``) mode,
            so an existing file is never overwritten.
            NOTE(review): presumably ``find_free_file`` returns a not-yet-
            existing variant of ``out_file`` — confirm in ``utils``.
        aug_config: Config written verbatim as the first output line
            (assumed JSON-serialisable).
        original_key: Field set to ``False`` on every augmented copy.
    """
    with find_free_file(out_file).open('x', encoding='utf-8') as ostream:
        # Dump the config first so it is easy to tell how the data was
        # augmented.
        ostream.write(f"{json.dumps(aug_config)}\n")
        for entry in tqdm(data):
            # The original example always precedes its variants.
            ostream.write(f"{json.dumps(entry)}\n")
            variants = augmenter.augment(entry[text_key])
            if isinstance(variants, str):
                variants = [variants]
            for var in variants:
                # A shallow copy suffices: only top-level keys are replaced.
                new_entry = copy(entry)
                new_entry.update({
                    text_key: var,
                    original_key: False
                })
                ostream.write(f"{json.dumps(new_entry)}\n")


def main():
    """Entry point: build the pipeline, load the data and write the output."""
    args = parse_args()
    # NOTE(review): the config file is loaded twice (pipeline + raw dict);
    # the parse_config helpers are opaque from here.
    pipeline = build_augmentation_pipeline(args.config)
    config = read_config(args.config)
    data = read_original_data(args.input_json)
    augment_dataset(
        augmenter=pipeline,
        data=data,
        text_key=args.text_key,
        out_file=args.output_json,
        aug_config=config
    )


if __name__ == '__main__':
    main()