in transducer.py [0:0]
def backward(ctx, grad_output):
    """Backpropagate through the per-sample GTN loss graphs saved by forward.

    For every batch element, ``gtn.backward`` is run on that element's loss
    graph, seeded with a scalar graph built from ``ctx.scales[b]``. Samples are
    dispatched through ``gtn.parallel_for``. Afterwards the graph gradients are
    copied into torch tensors, moved to ``grad_output``'s device and multiplied
    in place by ``grad_output / B``. NOTE(review): the ``/ B`` presumably
    mirrors a batch mean taken in forward -- confirm against forward.

    Args:
        ctx: autograd context populated by forward. This reads ``ctx.graphs``
            (loss graphs, emission graphs, transition graph), ``ctx.scales``,
            ``ctx.input_shape`` (B, T, C) and ``ctx.needs_input_grad``.
        grad_output: upstream gradient of the forward output.

    Returns:
        One entry per forward input, as autograd requires:
        - the emissions gradient of shape (B, T, C), or None when input 0 does
          not need a gradient;
        - None for target, tokens and lex;
        - the transition-parameter gradient, or None when input 4 does not
          need one;
        - None for the transitions graph and for the final input.
    """
    losses, emission_graphs, transition_graph = ctx.graphs
    per_sample_scales = ctx.scales
    batch_size, num_frames, num_classes = ctx.input_shape
    want_input_grad = ctx.needs_input_grad[0]

    # Host-side buffer (default dtype, no device given). The graph gradients
    # arrive as NumPy arrays, so rows are filled here and moved to the target
    # device afterwards.
    input_grad = None
    if want_input_grad:
        input_grad = torch.empty((batch_size, num_frames, num_classes))

    def _backprop_sample(idx):
        # Seed this sample's backward pass with its per-sample scale.
        seed = make_scalar_graph(per_sample_scales[idx])
        gtn.backward(losses[idx], seed)
        emission_graph = emission_graphs[idx]
        if not want_input_grad:
            return
        weights = emission_graph.grad().weights_to_numpy()
        # Each call writes only row `idx`, so workers never touch the same row.
        input_grad[idx] = torch.tensor(weights).view(1, num_frames, num_classes)

    gtn.parallel_for(_backprop_sample, range(batch_size))

    if want_input_grad:
        input_grad = input_grad.to(grad_output.device)
        input_grad *= grad_output / batch_size

    transition_grad = None
    if ctx.needs_input_grad[4]:
        # The transition graph's gradient is read only after every sample's
        # gtn.backward call has run.
        weights = transition_graph.grad().weights_to_numpy()
        transition_grad = torch.tensor(weights).to(grad_output.device)
        transition_grad *= grad_output / batch_size

    return (
        input_grad,
        None,  # target
        None,  # tokens
        None,  # lex
        transition_grad,  # transition params
        None,  # transitions graph
        None,  # final forward input: no gradient
    )