def backward()

in utils.py [0:0]


    def backward(ctx, grad_output):
        """Back-propagate the loss gradient to the emissions and transitions.

        Runs ``gtn.backward`` on every per-example loss graph saved in
        ``ctx.auxiliary_data``, converts the resulting GTN weight gradients to
        torch tensors scaled by the per-example factor in ``scales``, and
        combines them with ``grad_output``.

        Args:
            ctx: autograd context; ``ctx.auxiliary_data`` holds
                ``(losses, scales, emissions_graphs, transitions_graphs,
                in_shape)`` with ``in_shape`` unpacking to ``(B, T, C)``.
            grad_output: upstream gradient of the loss. Presumably a scalar
                gradient of a batch-averaged loss (the input gradient is
                divided by ``B``) -- TODO confirm against forward().

        Returns:
            A 4-tuple ``(input_grad, transitions_grad, None, None)``.
            ``input_grad`` has shape ``(B, T, C)`` and ``transitions_grad``
            has shape ``(C + 1, C)``; each is ``None`` when the corresponding
            forward input does not need a gradient. The trailing ``None``
            entries are the (non-differentiable) target and reduction slots.
        """
        (
            losses,
            scales,
            emissions_graphs,
            transitions_graphs,
            in_shape,
        ) = ctx.auxiliary_data
        B, T, C = in_shape
        input_grad = transitions_grad = None
        # Per-example gradients are assembled in CPU tensors (torch.from_numpy
        # produces CPU tensors) and moved to grad_output's device at the end.
        if ctx.needs_input_grad[0]:
            input_grad = torch.empty((B, T, C))
        if ctx.needs_input_grad[1]:
            transitions_grad = torch.empty((B, C + 1, C))

        def process(b):
            # Populate GTN gradients for example b; the second argument
            # presumably disables retaining the graph -- verify with gtn docs.
            gtn.backward(losses[b], False)
            # Each call writes only to row b, so calls for different examples
            # never touch the same memory.
            if input_grad is not None:
                grad = emissions_graphs[b].grad().weights_to_numpy()
                input_grad[b] = torch.from_numpy(grad).view(T, C) * scales[b]
            if transitions_grad is not None:
                grad = transitions_graphs[b].grad().weights_to_numpy()
                transitions_grad[b] = (
                    torch.from_numpy(grad).view(C + 1, C) * scales[b]
                )

        gtn.parallel_for(process, range(B))

        if input_grad is not None:
            # .to(device) targets grad_output's exact device (including the
            # CUDA index) and is a no-op when grad_output is on the CPU.
            input_grad = input_grad.to(grad_output.device)
            input_grad *= grad_output / B
        if transitions_grad is not None:
            # Average the per-example transitions gradients over the batch.
            # Reducing before the device transfer copies one (C + 1, C)
            # tensor instead of B of them.
            transitions_grad = transitions_grad.mean(0).to(grad_output.device)
            transitions_grad = transitions_grad * grad_output
        return (
            input_grad,
            transitions_grad,
            None,  # target
            None,  # reduction
        )