def backward()

in transducer.py [0:0]


    def backward(ctx, grad_output):
        """Propagate ``grad_output`` back through the saved GTN loss graphs.

        For every batch index, seeds ``gtn.backward`` on that example's loss
        graph with a scalar graph built from its entry in ``ctx.scales``; the
        per-example work is dispatched through ``gtn.parallel_for``.

        When gradients are requested:

        * for input 0, the emissions-graph gradients are gathered into a
          ``(B, T, C)`` tensor;
        * for input 4, the gradient of the transitions graph is read back.

        Each gradient is moved to ``grad_output``'s device and multiplied by
        ``grad_output / B``.

        Returns a 7-tuple. Judging by the slot labels, it holds one entry per
        ``forward`` argument. Only slots 0 and 4 can be non-``None``.
        """
        loss_graphs, emission_graphs, transitions_graph = ctx.graphs
        per_example_scales = ctx.scales
        batch, frames, classes = ctx.input_shape
        want_emission_grad = ctx.needs_input_grad[0]

        emission_grad = None
        if want_emission_grad:
            emission_grad = torch.empty((batch, frames, classes))

        def _backprop_one(idx):
            # Seed this example's backward pass with its own loss scale.
            seed = make_scalar_graph(per_example_scales[idx])
            gtn.backward(loss_graphs[idx], seed)
            graph = emission_graphs[idx]
            if not want_emission_grad:
                return
            weights = graph.grad().weights_to_numpy()
            # Each call writes only its own batch row of the shared buffer.
            emission_grad[idx] = torch.tensor(weights).view(1, frames, classes)

        gtn.parallel_for(_backprop_one, range(batch))

        # NOTE(review): the 1/B factor presumably mirrors a batch mean taken
        # in forward() -- confirm against the forward pass.
        if want_emission_grad:
            emission_grad = emission_grad.to(grad_output.device)
            emission_grad *= grad_output / batch

        transition_grad = None
        if ctx.needs_input_grad[4]:
            flat = transitions_graph.grad().weights_to_numpy()
            transition_grad = torch.tensor(flat).to(grad_output.device)
            transition_grad *= grad_output / batch

        # Slots: 0 inputs, 1 target, 2 tokens, 3 lex, 4 transition params,
        # 5 transitions graph, 6 last forward argument.
        grads = [None] * 7
        grads[0] = emission_grad
        grads[4] = transition_grad
        return tuple(grads)