in test.py [0:0]
def test(args):
    """Evaluate a checkpointed model on one dataset split.

    Loads the JSON config named by ``args.config``, constructs the dataset,
    preprocessor, model, and criterion it describes, restores weights from
    ``args.checkpoint_path``, then runs inference over the split, printing
    per-sample CER/WER plus hypothesis/reference text, and finally the
    aggregate loss, CER, and WER.

    Args:
        args: Namespace with attributes ``config`` (path to JSON config),
            ``disable_cuda`` (bool), ``split`` (dataset split name),
            ``checkpoint_path`` (checkpoint directory/file), and
            ``load_last`` (bool, checkpoint selection).

    Raises:
        ValueError: if ``config["data"]["dataset"]`` has no matching
            ``datasets/<name>.py`` module file.
    """
    with open(args.config, "r") as fid:
        config = json.load(fid)

    # CUDA unless explicitly disabled on the command line.
    device = torch.device("cpu" if args.disable_cuda else "cuda")

    # The dataset is a python module loaded by file path, so verify the
    # file exists before attempting the dynamic import.
    dataset = config["data"]["dataset"]
    if not os.path.exists(f"datasets/{dataset}.py"):
        raise ValueError(f"Unknown dataset {dataset}")
    dataset = utils.module_from_file("dataset", f"datasets/{dataset}.py")

    input_size = config["data"]["num_features"]
    data_path = config["data"]["data_path"]
    preprocessor = dataset.Preprocessor(
        data_path,
        num_features=input_size,
        tokens_path=config["data"].get("tokens", None),
        lexicon_path=config["data"].get("lexicon", None),
        use_words=config["data"].get("use_words", False),
        prepend_wordsep=config["data"].get("prepend_wordsep", False),
    )
    data = dataset.Dataset(data_path, preprocessor, split=args.split)
    loader = utils.data_loader(data, config)

    # The criterion determines the model's output size (e.g. token count).
    criterion, output_size = models.load_criterion(
        config.get("criterion_type", "ctc"),
        preprocessor,
        config.get("criterion", {}),
    )
    criterion = criterion.to(device)
    model = models.load_model(
        config["model_type"], input_size, output_size, config["model"]
    ).to(device)
    models.load_from_checkpoint(model, criterion, args.checkpoint_path, args.load_last)
    model.eval()

    meters = utils.Meters()
    # Fix: wrap inference in no_grad() — the original built autograd graphs
    # for every batch during evaluation, wasting memory and time. Outputs
    # are unchanged.
    with torch.no_grad():
        for inputs, targets in loader:
            outputs = model(inputs.to(device))
            # Losses are summed weighted by batch size so avg_loss is a
            # true per-sample mean.
            meters.loss += criterion(outputs, targets).item() * len(targets)
            meters.num_samples += len(targets)
            predictions = criterion.viterbi(outputs)
            for p, t in zip(predictions, targets):
                p = preprocessor.tokens_to_text(p)
                t = preprocessor.to_text(t)
                # Split into words and drop empties caused by leading /
                # trailing / doubled word separators.
                pw = list(filter(None, p.split(preprocessor.wordsep)))
                tw = list(filter(None, t.split(preprocessor.wordsep)))
                tokens_dist = editdistance.eval(p, t)
                words_dist = editdistance.eval(pw, tw)
                # Guard against empty references to avoid division by zero.
                print("CER: {:.3f}".format(
                    tokens_dist * 100.0 / len(t) if len(t) > 0 else 0))
                print("WER: {:.3f}".format(
                    words_dist * 100.0 / len(tw) if len(tw) > 0 else 0))
                print("HYP:", "".join(p))
                # Fix: label was "REF" (no colon), inconsistent with "HYP:".
                print("REF:", "".join(t))
                print("=" * 80)
                meters.edit_distance_tokens += tokens_dist
                meters.edit_distance_words += words_dist
                meters.num_tokens += len(t)
                meters.num_words += len(tw)

    print(
        "Loss {:.3f}, CER {:.3f}, WER {:.3f}, ".format(
            meters.avg_loss, meters.cer, meters.wer
        )
    )