def backward()

in transducer.py

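Backward pass of the custom autograd function that applies gtn transductions over strided windows of the input. It seeds each window's output graph with the corresponding upstream gradient, reads the resulting gradient off the window's input graph, and overlap-adds it into the (B, T, C) input gradient; kernel-weight gradients are materialized only when PyTorch requests them.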

    @staticmethod
    def backward(ctx, grad_output):
        # Graphs can't be saved with ctx.save_for_backward (they aren't
        # tensors), so the forward pass presumably stashes them in the
        # module-level CTX_GRAPHS instead.
        output_graphs, input_graphs, kernels = CTX_GRAPHS
        B, T, C = ctx.input_shape
        kernel_size = ctx.kernel_size
        stride = ctx.stride
        # Accumulate on CPU; moved to grad_output.device at return.
        input_grad = torch.zeros((B, T, C))
        # Upstream gradients: one scalar per (example, window, channel).
        deltas = grad_output.cpu().numpy()

        def process(b):
            # Backpropagate through every output window of example b.
            for t, window in enumerate(output_graphs[b]):
                for c, out in enumerate(window):
                    # Seed channel c's output graph with the upstream scalar
                    # gradient, wrapped as a single-arc gtn graph.
                    delta = make_scalar_graph(deltas[b, t, c])
                    gtn.backward(out, delta)
                # Read the accumulated gradient off the window's input graph
                # and overlap-add it into the matching slice of time steps
                # (see the indexing sketch after this listing).
                grad = (
                    input_graphs[b][t]
                    .grad()
                    .weights_to_numpy()
                    .reshape(kernel_size, -1)
                )
                input_grad[b, t * stride : t * stride + kernel_size] += (
                    torch.from_numpy(grad)
                )

        # Run the per-example backward in parallel across the batch.
        gtn.parallel_for(process, range(B))

        if ctx.needs_input_grad[4]:
            # Gradient for the kernel weight tensor (the fifth argument to
            # forward): concatenate the per-kernel gtn gradients.
            kernel_grads = [k.grad().weights_to_numpy() for k in kernels]
            kernel_grads = np.concatenate(kernel_grads)
            kernel_grads = torch.from_numpy(kernel_grads).to(grad_output.device)
        else:
            kernel_grads = None
        # One gradient (or None) per argument to forward.
        return (
            input_grad.to(grad_output.device),
            None,  # kernels
            None,  # kernel_size
            None,  # stride
            kernel_grads,
            None,  # viterbi
        )
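
The make_scalar_graph helper used above is not part of this excerpt. A minimal sketch consistent with how it is called here, wrapping one float as a single-arc gtn graph so it can seed gtn.backward, might look like:

    import gtn

    def make_scalar_graph(weight):
        # One start node, one accept node, and a single arc carrying
        # the scalar as its weight.
        g = gtn.Graph()
        g.add_node(True)
        g.add_node(False, True)
        g.add_arc(0, 1, 0, 0, weight)
        return g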
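
To illustrate the overlap-add indexing in process, here is a small self-contained sketch (the sizes and the all-ones window gradient are made up for illustration). Window t writes rows t * stride through t * stride + kernel_size - 1, so with stride < kernel_size adjacent windows overlap and the shared rows accumulate two contributions:

    import torch

    B, T, C = 1, 7, 4
    kernel_size, stride = 3, 2
    num_windows = (T - kernel_size) // stride + 1  # 3 windows over 7 steps
    input_grad = torch.zeros((B, T, C))
    for t in range(num_windows):
        # Stand-in for the per-window gradient read off the input graph.
        grad = torch.ones((kernel_size, C))
        input_grad[0, t * stride : t * stride + kernel_size] += grad
    # Rows 2 and 4 are covered by two windows and accumulate twice:
    print(input_grad[0, :, 0])  # tensor([1., 1., 2., 1., 2., 1., 1.])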