in align/models.py [0:0]
def __init__(self, path='criss/criss-3rd.pt',
             args_path='criss/args.pt',
             tokenizer='facebook/mbart-large-cc25', device='cpu', distortion=0,
             matching_method='a'
             ):
    """Load a fairseq model ensemble and a tokenizer, and store settings.

    Args:
        path: Checkpoint path(s) for the ensemble; several checkpoints are
            given as one string separated by ``:``.
        args_path: File read with ``torch.load`` that holds the fairseq
            argument object used to set up the task. The attributes read
            here are ``no_beamable_mm``, ``beam``, ``print_alignment`` and
            ``fp16``.
        tokenizer: Name or path passed to ``AutoTokenizer.from_pretrained``.
        device: Device every model and the ensemble wrapper are moved to.
        distortion: Stored unchanged on ``self.distortion``; not used in
            this method.
        matching_method: Stored unchanged on ``self.matching_method``; not
            used in this method.
    """
    # Imported inside the method, as in the original code -- presumably so
    # that fairseq is only needed when this class is instantiated (confirm).
    # Only the names this method actually uses are imported.
    from fairseq import checkpoint_utils, tasks
    from fairseq.sequence_generator import EnsembleModel

    self.device = device

    # NOTE(security): torch.load unpickles the file, so args_path must come
    # from a trusted source.
    args = torch.load(args_path)
    task = tasks.setup_task(args)

    # An empty dict literal replaces the former eval('{}'): same value,
    # no dynamic evaluation.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        path.split(':'),
        arg_overrides={},
        task=task
    )

    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        # Module.to() moves parameters in place, so no rebinding is needed.
        model.to(self.device)

    self.model = EnsembleModel(models).to(self.device)
    self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
    self.distortion = distortion
    self.matching_method = matching_method