in scripts/build_transitions.py [0:0]
def build_graph(ngrams, disable_backoff=False):
    """Build a GTN graph whose nodes are n-gram contexts and whose arcs are grams.

    Each gram ``(w_1, ..., w_k)`` contributes one arc from the node of its
    context ``(w_1, ..., w_{k-1})`` to the node of the state reached after
    consuming ``w_k`` (at most the last ``ngram - 1`` tokens of the gram).
    Nodes are created lazily and cached per state tuple.

    Args:
        ngrams: Sequence whose length is taken as the model order. Each
            element is an iterable of grams; every gram must be sliceable
            and its slices hashable (presumably tuples of token indices,
            with ``ngrams[k]`` holding the (k+1)-grams -- confirm against
            the caller).
        disable_backoff: When True, no epsilon back-off arcs are added to
            newly created nodes.

    Returns:
        The populated ``gtn.Graph``.

    Raises:
        ValueError: If a gram has no ``END_IDX`` in ``gram[1:]`` and
            ``gram[1:]`` is not already a known state.
    """
    graph = gtn.Graph(False)
    ngram = len(ngrams)
    state_to_node = {}

    def get_node(state):
        """Return the node for ``state``, creating (and caching) it if needed."""
        node = state_to_node.get(state)
        # Compare against None explicitly: the cached node may be falsy (e.g. 0).
        if node is not None:
            return node

        if ngram > 1:
            start = state == (START_IDX,)
            end = state == (END_IDX,)
        else:
            # With order <= 1 every node is flagged as both start and accept.
            start = end = True

        node = graph.add_node(start, end)
        state_to_node[state] = node

        if not disable_backoff and not end:
            # Link the new node to the longest proper suffix of `state` that
            # already has a node, via an epsilon back-off transition.
            for n in range(1, len(state) + 1):
                back_off_node = state_to_node.get(state[n:])
                if back_off_node is not None:
                    graph.add_arc(node, back_off_node, gtn.epsilon)
                    break
        return node

    for grams in ngrams:
        for gram in grams:
            istate = gram[:-1]
            # Keep at most the last (ngram - 1) tokens. The start is clamped
            # at 0 because a negative start would be read as "last k items"
            # and silently drop leading tokens for grams shorter than
            # ngram - 1 (e.g. a bigram in a 4-gram model).
            ostate = gram[max(len(gram) - ngram + 1, 0):]
            inode = get_node(istate)

            suffix = gram[1:]
            if END_IDX not in suffix and suffix not in state_to_node:
                raise ValueError(
                    "Ill formed counts: if (x, y_1, ..., y_{n-1}) is above "
                    "the n-gram threshold, then (y_1, ..., y_{n-1}) must be "
                    "above the (n-1)-gram threshold"
                )

            if END_IDX in ostate:
                # Merge every state containing </s> into a single end state,
                # since the final graph generated will be similar.
                ostate = (END_IDX,)
            onode = get_node(ostate)

            # Arc for p(gram[-1] | gram[:-1]); </s> is emitted as epsilon.
            label = gtn.epsilon if gram[-1] == END_IDX else gram[-1]
            graph.add_arc(inode, onode, label)
    return graph