in fairseq_cli/interactive.py [0:0]
def main(cfg: FairseqConfig):
    """Translate sentences read from ``cfg.interactive.input`` and print results.

    Loads the model ensemble named by ``cfg.common_eval.path``, builds a
    generator, then reads input in buffers of ``cfg.interactive.buffer_size``
    sentences. Each buffer is batched, translated, and printed in input order.

    For every input sentence the following tab-separated lines are printed:

    * ``S-<id>``: the source string (empty if the task has no source dictionary)
    * ``W-<id>``: the sentence's share of its batch's translation time, seconds
    * ``C-<id>``: one line per decoding constraint, if any
    * ``H-<id>``: score and hypothesis as returned by post-processing
    * ``D-<id>``: score and hypothesis after BPE/tokenizer decoding
    * ``P-<id>``: per-position scores
    * ``A-<id>``: alignment pairs (only if ``cfg.generation.print_alignment``)

    Scores are converted from natural log to base 2 before printing.

    Args:
        cfg: run configuration. A ``Namespace`` is accepted and converted with
            ``convert_namespace_to_omegaconf`` first. ``buffer_size`` and
            ``batch_size`` may be adjusted in place (see below).

    Raises:
        AssertionError: if sampling is enabled with ``nbest != beam``, or if
            the batch size exceeds the buffer size.
    """
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    start_time = time.time()
    total_translate_time = 0

    utils.import_user_module(cfg.common)

    # Normalise sizes: at least one sentence per buffer, and a batch size of 1
    # when neither a token budget nor a batch size was given.
    if cfg.interactive.buffer_size < 1:
        cfg.interactive.buffer_size = 1
    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.batch_size = 1

    # NOTE: these checks are plain asserts and are skipped under ``python -O``.
    assert (
        not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        not cfg.dataset.batch_size
        or cfg.dataset.batch_size <= cfg.interactive.buffer_size
    ), "--batch-size cannot be larger than --buffer-size"

    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(cfg.task)

    # Load ensemble
    overrides = ast.literal_eval(cfg.common_eval.model_overrides)
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        # Strict state-dict loading only when the checkpoint is not sharded.
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Initialize generator
    generator = task.build_generator(models, cfg.generation)

    # Handle tokenization and BPE
    tokenizer = task.build_tokenizer(cfg.tokenizer)
    bpe = task.build_bpe(cfg.bpe)

    def encode_fn(x):
        """Apply the tokenizer, then BPE, skipping whichever is ``None``."""
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        """Undo ``encode_fn``: BPE decode first, then tokenizer decode."""
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models]
    )

    if cfg.generation.constraints:
        logger.warning(
            "NOTE: Constrained decoding currently assumes a shared subword vocabulary."
        )

    if cfg.interactive.buffer_size > 1:
        logger.info("Sentence buffer size: %s", cfg.interactive.buffer_size)
    logger.info("NOTE: hypothesis and token scores are output in base 2")
    logger.info("Type the input sentence and press return:")

    # Offset added to per-buffer sentence ids so printed ids keep increasing
    # across buffers.
    start_id = 0
    for inputs in buffered_read(cfg.interactive.input, cfg.interactive.buffer_size):
        results = []
        for batch in make_batches(inputs, cfg, task, max_positions, encode_fn):
            bsz = batch.src_tokens.size(0)
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            constraints = batch.constraints
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                if constraints is not None:
                    constraints = constraints.cuda()

            sample = {
                "net_input": {
                    "src_tokens": src_tokens,
                    "src_lengths": src_lengths,
                },
            }
            translate_start_time = time.time()
            translations = task.inference_step(
                generator, models, sample, constraints=constraints
            )
            translate_time = time.time() - translate_start_time
            total_translate_time += translate_time

            # One list of constraints per sentence; empty lists when
            # constrained decoding is disabled.
            if cfg.generation.constraints:
                list_constraints = [unpack_constraints(c) for c in constraints]
            else:
                list_constraints = [[] for _ in range(bsz)]

            for i, (sample_id, hypos) in enumerate(
                zip(batch.ids.tolist(), translations)
            ):
                # NOTE(review): source tokens are stripped with the *target*
                # dictionary's pad index; presumably both dictionaries share
                # the same pad index -- confirm.
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                results.append(
                    (
                        start_id + sample_id,
                        src_tokens_i,
                        hypos,
                        {
                            "constraints": list_constraints[i],
                            # Batch time split evenly across its sentences.
                            "time": translate_time / len(translations),
                        },
                    )
                )

        # sort output to match input order
        for id_, sent_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
            src_str = ""
            if src_dict is not None:
                src_str = src_dict.string(sent_tokens, cfg.common_eval.post_process)
                print("S-{}\t{}".format(id_, src_str))
                print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
                for constraint in info["constraints"]:
                    print(
                        "C-{}\t{}".format(
                            id_,
                            tgt_dict.string(constraint, cfg.common_eval.post_process),
                        )
                    )

            # Process top predictions (slicing already clamps to len(hypos))
            for hypo in hypos[: cfg.generation.nbest]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=cfg.common_eval.post_process,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                )
                detok_hypo_str = decode_fn(hypo_str)
                score = hypo["score"] / math.log(2)  # convert to base 2
                # original hypothesis (after tokenization and BPE)
                print("H-{}\t{}\t{}".format(id_, score, hypo_str))
                # detokenized hypothesis
                print("D-{}\t{}\t{}".format(id_, score, detok_hypo_str))
                # convert from base e to base 2; use the non-mutating ``div``
                # so the tensor stored in ``hypo`` is left untouched.
                positional_scores = hypo["positional_scores"].div(math.log(2)).tolist()
                print(
                    "P-{}\t{}".format(
                        id_,
                        " ".join("{:.4f}".format(s) for s in positional_scores),
                    )
                )
                if cfg.generation.print_alignment:
                    alignment_str = " ".join(
                        ["{}-{}".format(src, tgt) for src, tgt in alignment]
                    )
                    print("A-{}\t{}".format(id_, alignment_str))

        # update running id_ counter
        start_id += len(inputs)

    logger.info(
        "Total time: {:.3f} seconds; translation time: {:.3f}".format(
            time.time() - start_time, total_translate_time
        )
    )