in pytorch_translate/rescoring/rescorer.py [0:0]
def __init__(self, args, forward_task=None, models=None):
    """Set up the optional scorers used for rescoring.

    Args:
        args: options object. This method reads its ``l2r_model_path``,
            ``r2l_model_path``, ``reverse_model_path``, ``lm_model_path``
            and ``cloze_transformer_path`` attributes. It is stored as
            ``self.args`` and passed to every scorer constructor.
        forward_task: forwarded unchanged to every scorer constructor.
        models: optional dict of already-loaded models, shaped like
            ``{'l2r_model': {'model': model, 'task': task}, ...}``. The keys
            looked up are ``l2r_model``, ``r2l_model``, ``reverse_model``,
            ``lm_model`` and ``cloze_model``. ``None`` is treated as an
            empty dict.

    Each scorer attribute is ``None`` unless its path option is truthy or
    ``models`` holds a truthy entry under the matching key. In that case the
    scorer is built as ``scorer_cls(args, path, models_entry, forward_task)``.
    """
    self.args = args
    preloaded_models = {} if models is None else models

    def build_scorer(scorer_cls, model_path, model_key):
        # Construct the scorer only when a checkpoint path or a preloaded
        # model entry is available; otherwise leave the slot empty.
        preloaded = preloaded_models.get(model_key)
        if not (model_path or preloaded):
            return None
        return scorer_cls(args, model_path, preloaded, forward_task)

    # Scorers are built in a fixed order: l2r, r2l, reverse, lm, cloze.
    self.l2r_model_scorer = build_scorer(
        SimpleModelScorer, args.l2r_model_path, "l2r_model"
    )
    self.r2l_model_scorer = build_scorer(
        R2LModelScorer, args.r2l_model_path, "r2l_model"
    )
    self.reverse_model_scorer = build_scorer(
        ReverseModelScorer, args.reverse_model_path, "reverse_model"
    )
    self.lm_scorer = build_scorer(LMScorer, args.lm_model_path, "lm_model")
    # The cloze model reuses SimpleModelScorer, the same class as the l2r
    # scorer.
    self.cloze_transformer_scorer = build_scorer(
        SimpleModelScorer, args.cloze_transformer_path, "cloze_model"
    )