in models.py [0:0]
def load_criterion(criterion_type, preprocessor, config):
    """Build the training criterion selected by *criterion_type*.

    Args:
        criterion_type: One of ``"asg"``, ``"ctc"`` or ``"transducer"``.
        preprocessor: Object exposing ``num_tokens`` (read for every criterion
            type) and, for ``"transducer"``, ``tokens`` and
            ``graphemes_to_index``.
        config: Mapping of criterion options, read with ``config.get``. Every
            option has a default, so an empty dict is accepted.

    Returns:
        A ``(criterion, output_size)`` tuple. ``output_size`` is
        ``num_tokens`` plus the extra symbols the criterion adds:
        ``num_replabels`` repetition labels and an optional garbage token for
        ASG, one blank for CTC, and one blank for the transducer unless
        ``blank == "none"``. It is presumably the output dimension the model
        must produce; confirm against callers.

    Raises:
        ValueError: If *criterion_type* is not one of the supported values.
    """
    num_tokens = preprocessor.num_tokens
    if criterion_type == "asg":
        num_replabels = config.get("num_replabels", 0)
        use_garbage = config.get("use_garbage", True)
        return (
            ASG(num_tokens, num_replabels, use_garbage),
            num_tokens + num_replabels + int(use_garbage),
        )
    elif criterion_type == "ctc":
        use_pt = config.get("use_pt", True)  # use pytorch implementation
        return CTC(num_tokens, use_pt), num_tokens + 1  # account for blank
    elif criterion_type == "transducer":
        blank = config.get("blank", "none")
        transitions = config.get("transitions", None)
        if transitions is not None:
            # The config value is passed to gtn.load (presumably a file path
            # to a serialized transition graph; confirm against configs).
            transitions = gtn.load(transitions)
        criterion = transducer.Transducer(
            preprocessor.tokens,
            preprocessor.graphemes_to_index,
            ngram=config.get("ngram", 0),
            transitions=transitions,
            blank=blank,
            allow_repeats=config.get("allow_repeats", True),
            reduction="mean",
        )
        # A blank symbol adds one output unit unless disabled with "none".
        return criterion, num_tokens + int(blank != "none")
    raise ValueError(
        f"Unknown criterion type {criterion_type!r}; "
        "expected one of 'asg', 'ctc', 'transducer'"
    )