augmentation/parse_config.py (54 lines of code) (raw):
from pathlib import Path
from typing import Union
import nlpaug.augmenter.word as naw
from nlpaug.flow import Sequential, Sometimes
import yaml
from augmenters import MultiVariantBackTranslationAug
from filtering import FilterAugmented, ThresholdAcceptor
from metrics import single_reference_sentence_bleu
def read_config(config_path: Union[Path, str]):
    """Parse a YAML configuration file.

    Args:
        config_path: Location of the YAML file, given as a ``Path`` or a
            plain string.

    Returns:
        The object produced by ``yaml.load`` with ``yaml.FullLoader`` for
        the file's contents.
    """
    yaml_file = Path(config_path)
    with yaml_file.open('r') as handle:
        parsed = yaml.load(handle, Loader=yaml.FullLoader)
    return parsed
def build_augmentation_pipeline(config_file: Path):
    """Build the augmentation pipeline described by a configuration file.

    The file is read with ``read_config``; its ``"flow"`` section is turned
    into an augmentation flow by ``build_flow`` and its ``"filtration"``
    section is then applied to that flow by ``build_filtration``.

    Args:
        config_file: Path of the configuration file to read.

    Returns:
        Whatever ``build_filtration`` returns for the built flow.
    """
    pipeline_config = read_config(config_file)
    augmentation_flow = build_flow(pipeline_config["flow"])
    filtration_section = pipeline_config["filtration"]
    return build_filtration(augmentation_flow, filtration_section)
def build_flow(flow_config):
    """Assemble an augmentation flow from the ``flow`` config section.

    Every key other than ``"settings"`` is treated as an augmenter name and
    its value as that augmenter's arguments; each pair is handed to
    ``build_augmenter`` in the mapping's iteration order.

    Args:
        flow_config: Mapping of augmenter names to their arguments, plus a
            required ``"settings"`` entry holding ``"is_random"`` and, when
            that is truthy, ``"aug_p"``.

    Returns:
        A ``Sometimes`` flow configured with ``aug_p`` when ``is_random`` is
        truthy, otherwise a ``Sequential`` flow over the augmenters.
    """
    built_augmenters = [
        build_augmenter(name, params)
        for name, params in flow_config.items()
        if name != "settings"
    ]

    flow_settings = flow_config['settings']
    if flow_settings['is_random']:
        return Sometimes(built_augmenters, aug_p=flow_settings['aug_p'])
    return Sequential(built_augmenters)
def build_augmenter(key, kwargs):
    """Instantiate the augmenter registered under ``key``.

    Args:
        key: Augmenter name; must be one of the keys of
            ``AUGMENTER_CLASSES`` below.
        kwargs: Mapping of keyword arguments for the augmenter's
            constructor, or ``None`` for "no arguments" (which is what a
            YAML key with an empty value parses to). A ``"stopwords"``
            entry is resolved through ``get_stopwords`` before the
            augmenter is constructed. The mapping itself is not modified.

    Returns:
        The constructed augmenter instance.

    Raises:
        KeyError: If ``key`` is not a supported augmenter name.
    """
    AUGMENTER_CLASSES = {
        "MultiVariantBackTranslationAug": MultiVariantBackTranslationAug,
        "SynonymAug": naw.SynonymAug,
        "ContextualWordEmbsAug": naw.ContextualWordEmbsAug,
    }

    # Work on a shallow copy so the caller's config mapping is left intact,
    # and treat a missing/empty argument section as "use the defaults".
    params = dict(kwargs) if kwargs else {}

    if 'stopwords' in params:
        params['stopwords'] = get_stopwords(params['stopwords'])

    return AUGMENTER_CLASSES[key](**params)
def get_stopwords(stopwords):
    """Return stop words as a list, reading them from a file if needed.

    Args:
        stopwords: Either a list, which is returned unchanged, or a path to
            a text file.

    Returns:
        The given list itself, or the whitespace-stripped lines of the file
        in file order, with lines that are empty after stripping omitted.
    """
    if not isinstance(stopwords, list):
        with open(stopwords, 'r') as source:
            stripped_lines = [line.strip() for line in source]
        # Drop blank lines (an empty string is falsy).
        return [word for word in stripped_lines if word]
    return stopwords
def build_filtration(flow, filtration_config):
    """Optionally wrap ``flow`` in a metric-based filter.

    Args:
        flow: The augmentation flow to wrap.
        filtration_config: Mapping with a ``"metric_fn"`` entry. When that
            entry is ``None`` no filtering is applied. Otherwise it must
            name a supported metric (currently only ``"bleu"``) and the
            mapping must also provide ``"low"`` and ``"high"``, which are
            passed to ``ThresholdAcceptor``.

    Returns:
        ``flow`` itself when ``metric_fn`` is ``None``; otherwise a
        ``FilterAugmented`` instance built around ``flow``.

    Raises:
        ValueError: If ``metric_fn`` names a metric that is not supported.
    """
    metric_name = filtration_config['metric_fn']
    if metric_name is None:
        return flow

    METRICS = {
        "bleu": single_reference_sentence_bleu,
    }
    # Honour the configured metric name instead of hard-coding BLEU, so a
    # typo or an unsupported metric is reported rather than silently ignored.
    try:
        metric_fn = METRICS[metric_name]
    except KeyError:
        raise ValueError(
            f"Unsupported filtration metric {metric_name!r}; "
            f"expected one of {sorted(METRICS)}"
        ) from None

    return FilterAugmented(
        flow,
        metric_fn=metric_fn,
        metric_acceptor=ThresholdAcceptor(
            filtration_config['low'],
            filtration_config['high'],
        ),
    )