def backward()

in python/lltm_baseline.py [0:0]


    @staticmethod
    def backward(ctx, grad_h, grad_cell):
        # Unpack the tensors stashed by forward() via ctx.save_for_backward()
        # (saved_tensors is the current name of the deprecated saved_variables).
        X, weights, input_gate, output_gate, old_cell = ctx.saved_tensors[:5]
        new_cell, candidate_cell, gate_weights = ctx.saved_tensors[5:]

        d_input = d_weights = d_bias = d_old_h = d_old_cell = None

        # Backprop through new_h = output_gate * tanh(new_cell).
        d_output_gate = torch.tanh(new_cell) * grad_h
        d_tanh_new_cell = output_gate * grad_h
        # Chain rule through tanh, plus the gradient flowing in on the cell state.
        d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell

        # Backprop through new_cell = old_cell + candidate_cell * input_gate.
        d_old_cell = d_new_cell
        d_candidate_cell = input_gate * d_new_cell
        d_input_gate = candidate_cell * d_new_cell

        # Backprop through the gate non-linearities, using the pre-activation
        # values (gate_weights) saved by forward().
        gates = gate_weights.chunk(3, dim=1)
        d_input_gate *= d_sigmoid(gates[0])
        d_output_gate *= d_sigmoid(gates[1])
        d_candidate_cell *= d_elu(gates[2])

        d_gates = torch.cat(
            [d_input_gate, d_output_gate, d_candidate_cell], dim=1)

        # needs_input_grad mirrors forward's arguments:
        # 0 = input, 1 = weights, 2 = bias, 3 = old_h, 4 = old_cell.
        if ctx.needs_input_grad[1]:
            d_weights = d_gates.t().mm(X)
        if ctx.needs_input_grad[2]:
            d_bias = d_gates.sum(dim=0, keepdim=True)
        # X was formed as cat([old_h, input], dim=1) in forward(), so splitting
        # d_X recovers d_old_h and d_input; guard on the arguments those
        # gradients belong to (input is index 0, old_h is index 3).
        if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
            d_X = d_gates.mm(weights)
            state_size = grad_h.shape[1]
            d_old_h, d_input = d_X[:, :state_size], d_X[:, state_size:]

        return d_input, d_weights, d_bias, d_old_h, d_old_cell
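
The backward pass leans on elementwise derivative helpers `d_sigmoid`, `d_tanh`, and `d_elu` defined elsewhere in `python/lltm_baseline.py`. They are not shown in this excerpt; the sketch below assumes they compute the standard derivatives of sigmoid, tanh, and ELU at the gate pre-activations:

    import torch

    def d_sigmoid(z):
        # d/dz sigmoid(z) = sigmoid(z) * (1 - sigmoid(z))
        s = torch.sigmoid(z)
        return (1 - s) * s

    def d_tanh(z):
        # d/dz tanh(z) = 1 - tanh(z)^2
        t = torch.tanh(z)
        return 1 - (t * t)

    def d_elu(z, alpha=1.0):
        # d/dz elu(z) = 1 for z > 0, alpha * exp(z) otherwise
        e = z.exp()
        mask = (alpha * (e - 1)) < 0
        return torch.where(mask, alpha * e, torch.ones_like(z))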
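
A hand-written backward like this is easy to get subtly wrong, so it is worth checking against PyTorch's numerical gradients with `torch.autograd.gradcheck` in double precision. The snippet below is a hypothetical usage sketch: it assumes the surrounding `torch.autograd.Function` subclass is called `LLTMFunction` and that `forward` takes `(input, weights, bias, old_h, old_cell)` with the shapes implied by the code above.

    import torch
    from torch.autograd import gradcheck

    # Hypothetical sizes; LLTMFunction is assumed to be the Function class
    # that wraps the forward/backward shown in this file.
    batch, features, state = 3, 17, 5
    kwargs = {"dtype": torch.float64, "requires_grad": True}
    inputs = [
        torch.randn(batch, features, **kwargs),              # input
        torch.randn(3 * state, features + state, **kwargs),  # weights
        torch.randn(1, 3 * state, **kwargs),                 # bias
        torch.randn(batch, state, **kwargs),                 # old_h
        torch.randn(batch, state, **kwargs),                 # old_cell
    ]
    print(gradcheck(LLTMFunction.apply, inputs))  # expect True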