in transducer.py [0:0]
def backward(ctx, grad_output):
    """Propagate `grad_output` back through the per-window GTN graphs.

    Reads the (output, input, kernel) graphs stashed in the module-level
    CTX_GRAPHS, runs gtn.backward on every per-channel output graph, and
    folds each window's input-graph gradient back into the dense input
    tensor. Returns a 6-tuple of gradients matching the forward inputs:
    (input, kernels, kernel_size, stride, kernel weights, viterbi) —
    only slots 0 and 4 are ever non-None.
    """
    output_graphs, input_graphs, kernels = CTX_GRAPHS
    batch_size, num_frames, num_channels = ctx.input_shape
    kernel_size = ctx.kernel_size
    stride = ctx.stride

    # Accumulate on CPU; moved onto grad_output's device just before return.
    input_grad = torch.zeros((batch_size, num_frames, num_channels))
    deltas = grad_output.cpu().numpy()

    def process(batch_idx):
        # NOTE: each worker writes only row `batch_idx` of input_grad, so the
        # parallel workers never touch the same slice.
        for step, window in enumerate(output_graphs[batch_idx]):
            # Seed every per-channel output graph with its incoming delta;
            # gtn accumulates these into the shared input graph's gradient.
            for channel, out_graph in enumerate(window):
                seed = make_scalar_graph(deltas[batch_idx, step, channel])
                gtn.backward(out_graph, seed)
            # After all channels have contributed, pull the window's input
            # gradient and scatter-add it over the strided time span.
            window_grad = input_graphs[batch_idx][step].grad()
            window_grad = window_grad.weights_to_numpy().reshape(kernel_size, -1)
            start = step * stride
            input_grad[batch_idx, start : start + kernel_size] += window_grad

    gtn.parallel_for(process, range(batch_size))

    # Slot 4 of the forward inputs is the flat kernel-weight tensor.
    if ctx.needs_input_grad[4]:
        per_kernel = [k.grad().weights_to_numpy() for k in kernels]
        kernel_grads = torch.from_numpy(np.concatenate(per_kernel))
        kernel_grads = kernel_grads.to(grad_output.device)
    else:
        kernel_grads = None

    return (
        input_grad.to(grad_output.device),
        None,  # kernels
        None,  # kernel_size
        None,  # stride
        kernel_grads,
        None,  # viterbi
    )