in avhubert/infer_s2s.py
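# Imports used by this function, reconstructed so the excerpt is self-contained
# (the full file defines more; get_symbols_to_strip_from_output follows the
# helper in fairseq_cli/generate.py, so importing it from there is an assumption):
import hashlib
import json
import logging
import os
import sys
from itertools import chain

import editdistance
import numpy as np
import torch
from fairseq import checkpoint_utils, tasks, utils
from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter, TimeMeter
from fairseq_cli.generate import get_symbols_to_strip_from_output
from omegaconf import OmegaConf
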
def _main(cfg, output_file):
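    """Decode cfg.dataset.gen_subset with a fine-tuned AV-HuBERT seq2seq
    checkpoint, dump hypotheses to JSON, and write a WER summary under
    cfg.common_eval.results_path."""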
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=output_file,
    )
    logger = logging.getLogger("hybrid.speech_recognize")
    if output_file is not sys.stdout:  # also print to stdout
        logger.addHandler(logging.StreamHandler(sys.stdout))
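    # Register user modules, then load the model ensemble together with the
    # task config that was saved in the checkpoint.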
    utils.import_user_module(cfg.common)
    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([cfg.common_eval.path])
    use_cuda = torch.cuda.is_available()
    # Move models to the GPU only when one is available; an unconditional
    # .cuda() call would fail on CPU-only machines.
    models = [model.eval().cuda() if use_cuda else model.eval() for model in models]
    saved_cfg.task.modalities = cfg.override.modalities
    task = tasks.setup_task(saved_cfg.task)
    task.build_tokenizer(saved_cfg.tokenizer)
    task.build_bpe(saved_cfg.bpe)
    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)
    # Target dictionary used for detokenization and scoring
    dictionary = task.target_dictionary

    # Apply command-line overrides (noise settings, data and label paths) on
    # top of the task config restored from the checkpoint.
    task.cfg.noise_prob = cfg.override.noise_prob
    task.cfg.noise_snr = cfg.override.noise_snr
    task.cfg.noise_wav = cfg.override.noise_wav
    if cfg.override.data is not None:
        task.cfg.data = cfg.override.data
    if cfg.override.label_dir is not None:
        task.cfg.label_dir = cfg.override.label_dir
    # Loading the dataset must happen after the checkpoint has been loaded so
    # that the saved task config can be passed in.
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
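    # No external language model is fused; the generator still expects an LM
    # slot, so a single None entry stands in for it.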
    lms = [None]
    # Optimize ensemble for generation
    for model in chain(models, lms):
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)
    # Build the batch iterator over the evaluation subset (possibly sharded
    # across distributed workers).
    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]
        ),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )
    # Initialize generator
    if cfg.generation.match_source_len:
        logger.warning(
            "The option match_source_len is not applicable to speech recognition. Ignoring it."
        )
    gen_timer = StopwatchMeter()
    extra_gen_cls_kwargs = {
        "lm_model": lms[0],
        "lm_weight": cfg.generation.lm_weight,
    }
    cfg.generation.score_reference = False  # always decode, never score the reference
    save_attention_plot = cfg.generation.print_alignment is not None
    cfg.generation.print_alignment = None  # recorded above; the generator itself should not print alignments
    generator = task.build_generator(
        models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs
    )
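    # decode_fn turns a tensor of token IDs back into a word string: it uses
    # the dataset's own label processor when that defines decode(), otherwise
    # the character dictionary, where '|' marks a word boundary.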
    def decode_fn(x):
        symbols_ignore = get_symbols_to_strip_from_output(generator)
        symbols_ignore.add(dictionary.pad())
        if hasattr(task.datasets[cfg.dataset.gen_subset].label_processors[0], 'decode'):
            return task.datasets[cfg.dataset.gen_subset].label_processors[0].decode(x, symbols_ignore)
        chars = dictionary.string(x, extra_symbols_to_ignore=symbols_ignore)
        words = " ".join("".join(chars.split()).replace('|', ' ').split())
        return words
    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    result_dict = {'utt_id': [], 'ref': [], 'hypo': []}
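    # Decode each batch, logging the reference/hypothesis pair per utterance.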
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if "net_input" not in sample:
            continue
        prefix_tokens = None
        if cfg.generation.prefix_size > 0:
            prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]
        constraints = None
        if "constraints" in sample:
            constraints = sample["constraints"]
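        # Run (timed) beam search on the whole batch.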
        gen_timer.start()
        hypos = task.inference_step(
            generator,
            models,
            sample,
            prefix_tokens=prefix_tokens,
            constraints=constraints,
        )
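        # hypos[i] is a beam-sorted list of hypotheses; index 0 is the best one.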
        num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
        gen_timer.stop(num_generated_tokens)
        for i in range(len(sample["id"])):
            result_dict['utt_id'].append(sample['utt_id'][i])
            ref_sent = decode_fn(sample['target'][i].int().cpu())
            result_dict['ref'].append(ref_sent)
            best_hypo = hypos[i][0]['tokens'].int().cpu()
            hypo_str = decode_fn(best_hypo)
            result_dict['hypo'].append(hypo_str)
            logger.info(f"\nREF:{ref_sent}\nHYP:{hypo_str}\n")
        wps_meter.update(num_generated_tokens)
        progress.log({"wps": round(wps_meter.avg)})
        num_sentences += sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
logger.info("NOTE: hypothesis and token scores are output in base 2")
logger.info("Recognized {:,} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format(
num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
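    # Hash the generation config into a short ID so that runs with different
    # decoding settings write to distinct result files.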
    yaml_str = OmegaConf.to_yaml(cfg.generation)
    fid = int(hashlib.md5(yaml_str.encode("utf-8")).hexdigest(), 16)
    fid = fid % 1000000
    result_fn = f"{cfg.common_eval.results_path}/hypo-{fid}.json"
    with open(result_fn, 'w') as fp:  # context manager so the handle is closed
        json.dump(result_dict, fp, indent=4)
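    # Corpus-level WER: total edit distance divided by total reference words.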
    n_err, n_total = 0, 0
    assert len(result_dict['hypo']) == len(result_dict['ref'])
    for hypo, ref in zip(result_dict['hypo'], result_dict['ref']):
        hypo, ref = hypo.strip().split(), ref.strip().split()
        n_err += editdistance.eval(hypo, ref)
        n_total += len(ref)
    wer = 100 * n_err / n_total
    wer_fn = f"{cfg.common_eval.results_path}/wer.{fid}"
    with open(wer_fn, "w") as fo:
        fo.write(f"WER: {wer}\n")
        fo.write(f"err / num_ref_words = {n_err} / {n_total}\n\n")
        fo.write(f"{yaml_str}")
    logger.info(f"WER: {wer}%")
    return
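
# A typical invocation, sketched from the repo's decoding setup (the exact
# flags depend on conf/s2s_decode.yaml and on your checkpoint/data paths, so
# treat the values below as placeholders):
#
#   python -B infer_s2s.py --config-dir ./conf/ --config-name s2s_decode.yaml \
#       common.user_dir=`pwd` \
#       dataset.gen_subset=test \
#       common_eval.path=/path/to/checkpoint.pt \
#       common_eval.results_path=/path/to/decode/s2s/test \
#       override.modalities=['video']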