# augmentation/backtranslate.py (87 lines of code) (raw):
import argparse as ag
from pathlib import Path
import json
from typing import List, Dict, Generator, Iterable, Union
from copy import copy
from nlpaug import Augmenter
from nlpaug.flow import Sequential
from nlpaug.augmenter.word import BackTranslationAug
import json
from tqdm import tqdm
def parse_args():
    """Parse the command-line options of the back-translation script.

    The ``--chain`` value is given as dash-separated language codes
    (e.g. ``en-de-en``) and is converted to a list with ``parse_chain``
    before the namespace is returned.

    Every option except ``--device`` is mandatory: the rest of the script
    cannot run without them, so a missing one is reported by argparse as a
    usage error instead of surfacing later as an unrelated exception.

    Returns:
        The parsed ``argparse.Namespace`` with attributes ``input_json``,
        ``chain`` (list of language codes), ``output_json``, ``device``,
        ``text_key`` and ``batch_size``.
    """
    parser = ag.ArgumentParser(
        description="Augment the texts of a JSON dataset via back-translation."
    )
    parser.add_argument(
        "-i", "--input-json", type=Path, required=True,
        help="Path of the JSON file holding the original data.",
    )
    parser.add_argument(
        "-c", "--chain", type=str, required=True,
        help="Dash-separated chain of language codes, e.g. 'en-de-en'.",
    )
    parser.add_argument(
        "-o", "--output-json", type=Path, required=True,
        help="Path of the file the augmented entries are written to.",
    )
    parser.add_argument(
        "-d", "--device", type=str,
        help="Device string forwarded to the augmenter, e.g. 'cpu' or 'cuda'.",
    )
    parser.add_argument(
        "-t", "--text-key", type=str, required=True,
        help="Key of the text field inside every data entry.",
    )
    parser.add_argument(
        "-b", "--batch-size", type=int, required=True,
        help="Number of entries augmented per batch.",
    )
    args = parser.parse_args()
    # Turn the raw "a-b-c" string into ["a", "b", "c"] for downstream code.
    args.chain = parse_chain(args.chain)
    return args
def parse_chain(chain):
    """Split a dash-separated language chain into its individual codes.

    For example ``"en-de-en"`` becomes ``["en", "de", "en"]``.
    """
    separator = "-"
    language_codes = chain.split(separator)
    return language_codes
def find_free_file(path: Union[str, Path]) -> Path:
    """Return a path next to *path* that does not exist yet.

    If *path* itself is free it is returned unchanged. Otherwise the
    candidates ``<stem>_2<suffix>``, ``<stem>_3<suffix>``, ... in the same
    directory are probed in order and the first non-existing one is returned.
    """
    if isinstance(path, str):
        path = Path(path)

    candidate = path
    counter = 2
    while True:
        if not candidate.exists():
            return candidate
        # Numbering starts at 2: the un-numbered name counts as the first file.
        candidate = path.parent.joinpath(f"{path.stem}_{counter}{path.suffix}")
        counter += 1
def augment_dataset(
    augmenter: Augmenter,
    data: List[Dict],
    text_key: str,
    out_file: Path,
    batch_size: int,
    is_original_key: str = 'original'
) -> None:
    """Augment the text of every entry and write the results as JSON lines.

    The entries are processed in slices of ``batch_size``. For each slice
    the values under ``text_key`` are passed to ``augmenter.augment`` and
    every returned variant is written as one JSON object per line: a copy
    of the source entry whose ``text_key`` holds the variant and whose
    ``is_original_key`` is ``False``.

    The output is created in exclusive mode on the path returned by
    ``find_free_file(out_file)``.

    Args:
        augmenter: Object whose ``augment`` method maps a list of texts to
            a list of augmented texts.
        data: Entries to augment; each must contain ``text_key``.
        text_key: Key of the text field inside every entry.
        out_file: Desired output path.
        batch_size: Number of entries handed to the augmenter at once.
        is_original_key: Key that is set to ``False`` on written entries.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    # Validate before touching the file system: a zero step would make
    # range() fail only after an empty output file had been created, and a
    # negative step would silently produce an empty file.
    if batch_size <= 0:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size}"
        )

    with find_free_file(out_file).open("x") as ostream:
        for start in tqdm(range(0, len(data), batch_size)):
            batch = data[start:start + batch_size]
            examples = [entry[text_key] for entry in batch]
            variants = augmenter.augment(examples)
            for entry, variant in zip(batch, variants):
                # Shallow copy so the caller's entries are left untouched.
                new_entry = copy(entry)
                new_entry.update({
                    text_key: variant,
                    is_original_key: False,
                })
                ostream.write(f"{json.dumps(new_entry)}\n")
def build_chain_translation_augmenter(language_chain: List[str], device: str) -> Sequential:
    """Build a sequential flow of back-translation augmenters for a chain.

    The chain lists the languages a text travels through, e.g.
    ``["en", "de", "en"]`` or ``["en", "de", "en", "ru", "en"]``. Each
    ``BackTranslationAug`` is configured with two consecutive hops (a
    "from" model and a "to" model), so the hops are consumed in
    non-overlapping pairs and the chain must hold an odd number of at
    least three language codes.

    NOTE(review): this assumes ``BackTranslationAug`` applies
    ``from_model_name`` and then ``to_model_name``, and that ``Sequential``
    feeds each augmenter's output to the next one - confirm against the
    nlpaug documentation.

    Args:
        language_chain: Language codes in travel order.
        device: Device string forwarded to every ``BackTranslationAug``.

    Returns:
        A ``Sequential`` flow with one augmenter per pair of hops.

    Raises:
        ValueError: If the chain has fewer than three languages or an even
            number of languages (the last hop could not be paired).
        KeyError: If a hop has no entry in the model table.
    """
    pair_to_model = {
        "en-fr": "transformer.wmt14.en-fr",
        "en-de": "transformer.wmt19.en-de",
        "de-en": "transformer.wmt19.de-en",
        "en-ru": "transformer.wmt19.en-ru",
        "ru-en": "transformer.wmt19.ru-en"
    }
    chain_length = len(language_chain)
    if chain_length <= 2 or chain_length % 2 == 0:
        raise ValueError(
            "Can't backtranslate: the chain must contain an odd number of at "
            "least three languages (e.g. 'en-de-en'), "
            f"got {chain_length}"
        )

    # One "src-dst" key per hop between neighbouring languages.
    hops = [
        f"{src}-{dst}"
        for src, dst in zip(language_chain, language_chain[1:])
    ]

    # Check every hop up front so no model is loaded for an invalid chain.
    unsupported = [hop for hop in hops if hop not in pair_to_model]
    if unsupported:
        raise KeyError(
            f"No translation model for language pair(s) {unsupported}; "
            f"supported pairs: {sorted(pair_to_model)}"
        )

    # Step by two: each augmenter covers a full there-and-onward pair of
    # hops, so consecutive augmenters must not share a hop.
    augmenters = [
        BackTranslationAug(from_model_name=pair_to_model[hops[i]],
                           to_model_name=pair_to_model[hops[i + 1]],
                           device=device)
        for i in range(0, len(hops), 2)
    ]
    return Sequential(augmenters)
def read_original_data(path: Union[Path, str]) -> List[Dict]:
    """Load the original dataset and flatten it into a list of entries.

    The file must contain a JSON object whose values are themselves
    objects. Each value is extended with its key under ``'id'`` and with
    ``'original': True``, then collected in file order.

    Args:
        path: Location of the JSON file.

    Returns:
        The list of entry dictionaries, one per top-level key.
    """
    path = Path(path)
    # Read explicitly as UTF-8 so decoding does not depend on the
    # platform's default locale encoding.
    with path.open('r', encoding='utf-8') as istream:
        data = json.load(istream)
    result = []
    for key, entry in data.items():
        entry.update({'id': key, 'original': True})
        result.append(entry)
    return result
if __name__ == "__main__":
    # Script entry point: parse the CLI options, load the source entries,
    # build the augmenter for the requested chain and write the augmented
    # copies to the output file.
    args = parse_args()
    data = read_original_data(path=args.input_json)
    augmenter = build_chain_translation_augmenter(
        language_chain=args.chain,
        device=args.device,
    )
    augment_dataset(
        augmenter=augmenter,
        data=data,
        text_key=args.text_key,
        out_file=args.output_json,
        batch_size=args.batch_size,
    )