in utils.py [0:0]
def forward(ctx, inputs, transitions, targets, reduction="none"):
    """Compute the ASG loss for a batch using gtn graph operations.

    Args:
        ctx: autograd context; per-sample graphs and scales are stashed on
            ``ctx.auxiliary_data`` for the backward pass.
        inputs: emission scores of shape (B, T, C); moved to CPU per sample
            before being handed to gtn.
        transitions: transition scores tensor; copied to CPU once up front.
        targets: sequence of B target label sequences (indexable by batch).
        reduction: "none" (no scaling) or "mean" (scale each sample's loss
            by 1 / target length).

    Returns:
        A scalar tensor: the mean over the batch of the (optionally scaled)
        per-sample losses, placed on CUDA iff ``inputs`` is on CUDA.

    Raises:
        ValueError: if ``reduction`` is neither "none" nor "mean".
    """
    # Validate once, before spawning parallel work. Previously this check
    # lived inside `process`, so a bad value raised inside parallel_for
    # workers (per batch element) instead of failing fast in the caller.
    if reduction not in ("none", "mean"):
        raise ValueError("invalid value for reduction '" + str(reduction) + "'")

    B, T, C = inputs.shape
    losses = [None] * B
    scales = [None] * B
    emissions_graphs = [None] * B
    transitions_graphs = [None] * B
    calc_trans_grad = transitions.requires_grad
    transitions = transitions.cpu()  # avoid multiple cuda -> cpu copies

    def process(b):
        # Create the emission graph; gtn reads the weights directly from
        # the contiguous CPU tensor's data pointer.
        g_emissions = gtn.linear_graph(T, C, inputs.requires_grad)
        cpu_data = inputs[b].cpu().contiguous()
        g_emissions.set_weights(cpu_data.data_ptr())
        # Create the transition graph.
        g_transitions = ASGLossFunction.create_transitions_graph(
            transitions, calc_trans_grad
        )
        # Create the force-align criterion graph for this sample's target.
        g_fal = ASGLossFunction.create_force_align_graph(targets[b])
        # Loss = full-connect forward score minus force-align forward score.
        g_fal_fwd = gtn.forward_score(
            gtn.intersect(gtn.intersect(g_fal, g_transitions), g_emissions)
        )
        g_fcc_fwd = gtn.forward_score(gtn.intersect(g_emissions, g_transitions))
        g_loss = gtn.subtract(g_fcc_fwd, g_fal_fwd)
        scale = 1.0
        if reduction == "mean":
            # Normalize by target length; empty targets keep scale = 1.0
            # to avoid division by zero.
            L = len(targets[b])
            scale = 1.0 / L if L > 0 else scale
        # Save for backward:
        losses[b] = g_loss
        scales[b] = scale
        emissions_graphs[b] = g_emissions
        transitions_graphs[b] = g_transitions

    gtn.parallel_for(process, range(B))

    ctx.auxiliary_data = (
        losses,
        scales,
        emissions_graphs,
        transitions_graphs,
        inputs.shape,
    )
    loss = torch.tensor([losses[b].item() * scales[b] for b in range(B)])
    # Reduce to a scalar; keep the result on the same device as the input.
    return torch.mean(loss.cuda() if inputs.is_cuda else loss)