in python/lltm_baseline.py [0:0]
def backward(ctx, grad_h, grad_cell):
    # Unpack the tensors stashed by forward() via ctx.save_for_backward().
    X, weights, input_gate, output_gate, old_cell = ctx.saved_tensors[:5]
    new_cell, candidate_cell, gate_weights = ctx.saved_tensors[5:]

    d_input = d_weights = d_bias = d_old_h = d_old_cell = None

    # h = output_gate * tanh(new_cell), so backprop grad_h through both factors.
    d_output_gate = torch.tanh(new_cell) * grad_h
    d_tanh_new_cell = output_gate * grad_h
    d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell

    # new_cell = old_cell + candidate_cell * input_gate.
    d_old_cell = d_new_cell
    d_candidate_cell = input_gate * d_new_cell
    d_input_gate = candidate_cell * d_new_cell

    # Backprop through the gate non-linearities (sigmoid, sigmoid, ELU).
    gates = gate_weights.chunk(3, dim=1)
    d_input_gate *= d_sigmoid(gates[0])
    d_output_gate *= d_sigmoid(gates[1])
    d_candidate_cell *= d_elu(gates[2])

    d_gates = torch.cat(
        [d_input_gate, d_output_gate, d_candidate_cell], dim=1)

    # Only materialize gradients for inputs that actually require them.
    if ctx.needs_input_grad[1]:
        d_weights = d_gates.t().mm(X)
    if ctx.needs_input_grad[2]:
        d_bias = d_gates.sum(dim=0, keepdim=True)
    if ctx.needs_input_grad[3] or ctx.needs_input_grad[4]:
        # X = [old_h, input] concatenated along dim 1; split its gradient apart.
        d_X = d_gates.mm(weights)
        state_size = grad_h.shape[1]
        d_old_h, d_input = d_X[:, :state_size], d_X[:, state_size:]

    return d_input, d_weights, d_bias, d_old_h, d_old_cell
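
The backward pass above uses d_sigmoid, d_tanh and d_elu, which are not shown in this excerpt and are expected to be defined earlier in python/lltm_baseline.py. A minimal sketch, assuming the standard element-wise derivatives of sigmoid, tanh and ELU (alpha = 1):

import torch

def d_sigmoid(z):
    # sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z))
    s = torch.sigmoid(z)
    return (1 - s) * s

def d_tanh(z):
    # tanh'(z) = 1 - tanh(z)^2
    t = torch.tanh(z)
    return 1 - (t * t)

def d_elu(z, alpha=1.0):
    # elu'(z) = 1 for z > 0, alpha * exp(z) otherwise
    e = z.exp()
    mask = (alpha * (e - 1)) < 0
    return torch.where(mask, alpha * e, torch.ones_like(z))

Hand-written gradients like these can be sanity-checked against autograd with torch.autograd.gradcheck in double precision. The snippet below is illustrative only: it assumes an LLTMFunction autograd Function class (the presumed owner of the backward above) whose apply takes (input, weights, bias, old_h, old_cell), matching the order of the returned gradients, and uses made-up sizes.

import torch
from torch.autograd import gradcheck

# LLTMFunction is assumed to be importable from the baseline module.
batch, features, state = 3, 17, 5
kwargs = dict(dtype=torch.double, requires_grad=True)
variables = [
    torch.randn(batch, features, **kwargs),              # input
    torch.randn(3 * state, features + state, **kwargs),  # weights
    torch.randn(1, 3 * state, **kwargs),                 # bias
    torch.randn(batch, state, **kwargs),                 # old_h
    torch.randn(batch, state, **kwargs),                 # old_cell
]
print(gradcheck(LLTMFunction.apply, variables))  # True if analytic and numeric gradients agree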