in utils.py [0:0]
def backward(ctx, grad_output):
    """Compute gradients w.r.t. the emissions input and the transitions.

    Unpacks the state stored in ``ctx.auxiliary_data`` (presumably saved by
    the matching ``forward`` -- confirm). Runs ``gtn.backward`` on each
    per-example loss graph via ``gtn.parallel_for``, then reads the gradients
    of that example's emissions and transitions graphs, each multiplied by
    ``scales[b]``.

    Args:
        ctx: Autograd context. ``ctx.needs_input_grad[0]`` and
            ``ctx.needs_input_grad[1]`` select whether the input and the
            transitions gradients are computed.
        grad_output: Upstream gradient of the loss output.

    Returns:
        A 4-tuple ``(input_grad, transitions_grad, None, None)``.

        - ``input_grad`` has shape ``(B, T, C)`` and is scaled by
          ``grad_output / B``.
        - ``transitions_grad`` is the mean over the batch of the per-example
          ``(C + 1, C)`` gradients, times ``grad_output``.
        - Either of the first two is ``None`` when it was not requested.
        - The trailing ``None`` entries are for the ``target`` and
          ``reduction`` arguments, which receive no gradient.
    """
    (
        losses,
        scales,
        emissions_graphs,
        transitions_graphs,
        in_shape,
    ) = ctx.auxiliary_data
    B, T, C = in_shape

    # Buffers are allocated on the default device; they are moved to
    # grad_output's device after the per-example loop.
    input_grad = torch.empty((B, T, C)) if ctx.needs_input_grad[0] else None
    transitions_grad = (
        torch.empty((B, C + 1, C)) if ctx.needs_input_grad[1] else None
    )

    def process(b):
        # Second argument False: presumably retain_graph -- confirm in gtn docs.
        gtn.backward(losses[b], False)
        if input_grad is not None:
            grad = emissions_graphs[b].grad().weights_to_numpy()
            input_grad[b] = torch.from_numpy(grad).view(T, C) * scales[b]
        if transitions_grad is not None:
            grad = transitions_graphs[b].grad().weights_to_numpy()
            transitions_grad[b] = (
                torch.from_numpy(grad).view(C + 1, C) * scales[b]
            )

    # Each call to process(b) writes only row b of the output buffers.
    gtn.parallel_for(process, range(B))

    # Move to grad_output's device explicitly rather than calling .cuda(),
    # which targets the *current* CUDA device and can mismatch on multi-GPU.
    device = grad_output.device
    if input_grad is not None:
        input_grad = input_grad.to(device)
        input_grad *= grad_output / B
    if transitions_grad is not None:
        transitions_grad = transitions_grad.to(device)
        transitions_grad = torch.mean(transitions_grad, 0) * grad_output
    return (
        input_grad,
        transitions_grad,
        None,  # target
        None,  # reduction
    )