augmentation/augment_bert.py (50 lines of code) (raw):
import pandas as pd
import nlpaug.augmenter.word as naw
from tqdm import tqdm
from pathlib import Path
import json
import argparse
from parse_config import get_stopwords
from utils import find_free_file
def augment_dataframe(original, augmenter, input_key, batch_size=1):
    """Return a copy of ``original`` whose ``input_key`` column holds augmented text.

    ``original`` itself is not modified. If it has an ``origin`` column, that
    column is carried over to the result under the name ``obtained_from``. The
    result's ``origin`` column is then set to ``'augmented'`` for every row.

    Args:
        original: DataFrame containing the texts to augment.
        augmenter: Object with an ``augment(list_of_texts)`` method returning one
            augmented text per input text (e.g. an nlpaug word augmenter).
        input_key: Name of the column holding the texts.
        batch_size: Number of texts passed to ``augmenter.augment`` per call.

    Returns:
        A new DataFrame with the same index and row order as ``original``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1, or the augmenter returns
            a different number of texts than it was given.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    # rename() returns a new frame; it must be kept, otherwise the provenance
    # column is lost when 'origin' is overwritten at the end.
    augmented = original.rename(columns={'origin': 'obtained_from'})

    texts = list(original[input_key])
    results = []
    # Always hand the augmenter a list (even for batch_size == 1) so every
    # batch is processed the same way and returns a list of texts.
    for start in tqdm(range(0, len(texts), batch_size)):
        batch = texts[start:start + batch_size]
        outputs = list(augmenter.augment(batch))
        if len(outputs) != len(batch):
            raise ValueError(
                f'augmenter returned {len(outputs)} texts for a batch of {len(batch)}'
            )
        results.extend(outputs)

    # Assign the whole column at once: writing through augmented[col].iloc[i]
    # (chained assignment) or through iterrows() rows never reaches the frame.
    augmented[input_key] = results
    augmented['origin'] = 'augmented'
    return augmented
def save_dataframe(dataframe, path):
    """Write ``dataframe`` as CSV to the path returned by ``find_free_file(path)``.

    ``find_free_file`` is a project helper (``utils``); presumably it picks a
    file name that does not collide with an existing file -- TODO confirm.
    The index is written as well (``to_csv`` defaults).
    """
    dataframe.to_csv(find_free_file(path))
def save_config(config, path):
    """Serialise ``config`` as JSON into the file chosen by ``find_free_file(path)``.

    The target file is opened for writing before serialisation happens, so a
    non-JSON-serialisable ``config`` raises after the file has been created.
    Assumes ``find_free_file`` returns a ``pathlib.Path``-like object with
    ``open()`` -- TODO confirm against ``utils``.
    """
    target = find_free_file(path)
    with target.open('w') as handle:
        handle.write(json.dumps(config))
def main(args):
    """Augment the text column of a CSV with nlpaug's ContextualWordEmbsAug.

    Reads ``args.input_csv`` (first column used as the index) and builds an
    insert-action augmenter from the model at ``args.bert_path`` (top_k=10,
    aug_min=2, aug_max=4, stopwords from ``get_stopwords(args.stopwords)``).
    It augments the ``args.text_key`` column and writes the result to
    ``args.output_csv``, creating the parent directory if needed.

    Args:
        args: Namespace with ``input_csv``, ``output_csv``, ``bert_path``,
            ``device``, ``stopwords`` and ``text_key``. An optional
            ``batch_size`` attribute sets the augmentation batch size; when it
            is absent or None, 2 is used (the previous hard-coded value).
    """
    df = pd.read_csv(args.input_csv, index_col=0)
    kwargs = {
        'top_k': 10,
        'action': 'insert',
        'model_path': args.bert_path,
        'aug_min': 2,
        'aug_max': 4,
        'stopwords': get_stopwords(args.stopwords)
    }
    augmenter = naw.ContextualWordEmbsAug(device=args.device, **kwargs)

    # Fall back to the historic batch size so callers building their own
    # namespace (without batch_size) keep working.
    batch_size = getattr(args, 'batch_size', None)
    if batch_size is None:
        batch_size = 2

    augmented = augment_dataframe(df, augmenter, args.text_key, batch_size=batch_size)

    output_path = Path(args.output_csv)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    augmented.to_csv(output_path)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Augment a CSV text column by inserting words with a BERT model (nlpaug).'
    )
    parser.add_argument('-i', '--input-csv', type=Path, required=True,
                        help='CSV to augment; its first column is used as the index.')
    parser.add_argument('-o', '--output-csv', type=Path, required=True,
                        help='Where to write the augmented CSV.')
    parser.add_argument('-b', '--bert-path', type=str, required=True,
                        help='Model name or path passed to nlpaug as model_path.')
    parser.add_argument('-d', '--device', type=str,
                        help='Device string passed to the nlpaug augmenter.')
    parser.add_argument('-s', '--stopwords', type=str,
                        help='Stopwords specification passed to get_stopwords.')
    # '--text_key' is kept as an alias for backward compatibility; '--text-key'
    # matches the hyphenated style of every other option.
    parser.add_argument('-t', '--text-key', '--text_key', dest='text_key', type=str,
                        required=True, help='Name of the column holding the text.')
    parser.add_argument('--batch-size', type=int, default=2,
                        help='Number of texts augmented per call (default: 2).')
    args = parser.parse_args()
    main(args)