in utils.py [0:0]
def forward(ctx, log_probs, targets, blank_idx=0, reduction="none"):
    """Compute the batch-averaged CTC loss using GTN graphs.

    For every batch entry an emissions graph is built from ``log_probs[b]``,
    intersected with the CTC criterion graph of ``targets[b]``, and the
    negated forward score of the intersection is taken as that sample's loss.

    Args:
        ctx: Autograd context; ``ctx.auxiliary_data`` is set to the tuple
            ``(losses, scales, emissions_graphs, log_probs.shape)`` for the
            backward pass.
        log_probs: Tensor whose shape unpacks as ``(B, T, C)``.
            NOTE(review): the raw data pointer of each slice is handed to
            GTN, so this presumably must be float32 -- confirm.
        targets: Indexable collection holding one target sequence per batch
            entry; ``len(targets[b])`` is used for ``"mean"`` scaling.
        blank_idx: Blank label index forwarded to
            ``CTCLossFunction.create_ctc_graph``.
        reduction: ``"none"`` leaves each sample's loss unscaled; ``"mean"``
            divides each sample's loss by its target length (samples with an
            empty target are left unscaled).

    Returns:
        A scalar tensor holding the mean over the batch of the (optionally
        length-scaled) per-sample losses, on the same device as
        ``log_probs``.

    Raises:
        ValueError: If ``reduction`` is neither ``"none"`` nor ``"mean"``.
    """
    # Validate once, before any graph is built, instead of raising from
    # inside the per-sample callback handed to gtn.parallel_for.
    if reduction not in ("none", "mean"):
        raise ValueError("invalid value for reduction '" + str(reduction) + "'")
    normalize_by_length = reduction == "mean"

    B, T, C = log_probs.shape
    # Pre-sized so each callback invocation writes only to its own slot.
    losses = [None] * B
    scales = [None] * B
    emissions_graphs = [None] * B

    def process(b):
        # Emissions graph: T steps over C classes, weighted by log_probs[b].
        g_emissions = gtn.linear_graph(T, C, log_probs.requires_grad)
        # GTN reads the weights through a raw pointer, so the slice has to be
        # a contiguous CPU buffer.
        cpu_data = log_probs[b].cpu().contiguous()
        g_emissions.set_weights(cpu_data.data_ptr())

        # Criterion graph encoding the alignments accepted for this target.
        g_criterion = CTCLossFunction.create_ctc_graph(targets[b], blank_idx)

        # Loss is the negated forward score of the composed graph.
        g_loss = gtn.negate(
            gtn.forward_score(gtn.intersect(g_emissions, g_criterion))
        )

        scale = 1.0
        if normalize_by_length:
            L = len(targets[b])
            # An empty target keeps scale 1.0 to avoid dividing by zero.
            scale = 1.0 / L if L > 0 else scale

        # Save for backward:
        losses[b] = g_loss
        scales[b] = scale
        emissions_graphs[b] = g_emissions

    gtn.parallel_for(process, range(B))

    ctx.auxiliary_data = (losses, scales, emissions_graphs, log_probs.shape)
    loss = torch.tensor(
        [g_loss.item() * scale for g_loss, scale in zip(losses, scales)]
    )
    # Follow the input's device; a bare .cuda() would pick the current
    # default GPU, which need not be the one log_probs lives on.
    return torch.mean(loss.to(log_probs.device))