in crypten/cryptensor.py [0:0]
def backward(self, grad_input=None, top_node=True):
"""
Backpropagates gradients through the computation graph. Gradients are
retained only in the leaf nodes of the graph; an intermediate node releases
its gradient once it has been propagated to its children.
"""
if self.requires_grad:
with CrypTensor.no_grad(): # disable autograd for backward pass
# in initial backward call, identify all required nodes:
if top_node:
self._identify_required_grads()
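# (this traversal marks the nodes that need a gradient and counts, in
# self.grad_expected, how many gradient contributions each node should
# wait for before it propagates further)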
# if undefined, set gradient input to one:
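# (as in PyTorch, a seed gradient of 1 is implied only for scalar outputs)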
if grad_input is None:
if self.nelement() == 1:
grad_input = self.new(torch.ones_like(self.data))
else:
raise RuntimeError(
"grad can be implicitly created only for scalar outputs"
)
# process gradient input:
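# each call to backward() delivers one gradient contribution; count it so
# that a node consumed by several operations waits for all of them (see the
# grad_expected check below) before recursing into its own children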
self.grad_received += 1
if self.grad is None:
self.grad = grad_input # store gradient...
else:
self.grad.add_(grad_input) # ... or accumulate gradient
# if we are in a leaf or if not all parents have backpropagated:
if len(self.children) == 0 or self.grad_received < self.grad_expected:
return # ... do not proceed.
# check that we can actually backpropagate:
if self.grad_fn is None:
raise ValueError("Cannot call backward() before forward().")
# perform backpropagation:
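# grad_fn.backward() returns one gradient per input of the forward function;
# each is matched positionally with a differentiable child in the loop below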
grad = self.grad_fn.backward(self.ctx, self.grad)
differentiable_children = [
x for x in self.children if self.ctx.is_differentiable(x)
]
self.ctx.reset() # free up memory used for context
# call backward function on children:
if not isinstance(grad, (list, tuple)):
grad = (grad,)
assert len(differentiable_children) <= len(grad), (
    "number of gradients is smaller than the number of differentiable children"
)
for idx, child in enumerate(differentiable_children):
child.backward(grad_input=grad[idx], top_node=False)
# clean up gradients except in leaf nodes:
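# (only leaf nodes keep .grad after backward(), as stated in the docstring)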
if len(differentiable_children) > 0:
self.grad = None
# remove node from graph:
self.children = []
self.grad_expected = 0
self.grad_received = 0
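
# Usage sketch (illustrative, not part of this module): how backward() is
# typically driven from user code. Assumes crypten has been initialized and
# that crypten.cryptensor(...) accepts a requires_grad flag, as in the CrypTen
# autograd tutorial; the values below are placeholders.
#
#     import crypten
#     import torch
#
#     crypten.init()
#     x = crypten.cryptensor(torch.tensor([1.0, 2.0, 3.0]), requires_grad=True)
#     loss = x.mul(x).sum()            # scalar output
#     loss.backward()                  # seeds the gradient with 1 and recurses
#     print(x.grad.get_plain_text())   # gradient retained only on the leaf x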