def backward()

in crypten/cryptensor.py


    def backward(self, grad_input=None, top_node=True):
        """
        Backpropagates gradient through the computation graph. The function
        only maintains the gradients in leaf nodes of the graph.
        """
        if self.requires_grad:
            with CrypTensor.no_grad():  # disable autograd for backward pass

                # in initial backward call, identify all required nodes:
                if top_node:
                    self._identify_required_grads()

                # if undefined, set gradient input to one:
                if grad_input is None:
                    if self.nelement() == 1:
                        grad_input = self.new(torch.ones_like(self.data))
                    else:
                        raise RuntimeError(
                            "grad can be implicitly created only for scalar outputs"
                        )

                # process gradient input:
                self.grad_received += 1
                if self.grad is None:
                    self.grad = grad_input  # store gradient...
                else:
                    self.grad.add_(grad_input)  # ... or accumulate gradient

                # if we are in a leaf or if not all parents have backpropagated:
                if len(self.children) == 0 or self.grad_received < self.grad_expected:
                    return  # ... do not proceed.

                # check that we can actually backpropagate:
                if self.grad_fn is None:
                    raise ValueError("Cannot call backward() before forward().")

                # perform backpropagation:
                grad = self.grad_fn.backward(self.ctx, self.grad)
                differentiable_children = [
                    x for x in self.children if self.ctx.is_differentiable(x)
                ]
                self.ctx.reset()  # free up memory used for context

                # call backward function on children:
                if not isinstance(grad, (list, tuple)):
                    grad = (grad,)
                assert len(differentiable_children) <= len(
                    grad
                ), "number of gradients does not match number of children"
                for idx, child in enumerate(differentiable_children):
                    child.backward(grad_input=grad[idx], top_node=False)

                # clean up gradients except in leaf nodes:
                if len(differentiable_children) > 0:
                    self.grad = None

                # remove node from graph:
                self.children = []
                self.grad_expected = 0
                self.grad_received = 0
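
For reference, a minimal usage sketch of how this backward pass is typically driven from user code. It assumes a standard CrypTen setup (crypten.init(), a crypten.cryptensor(..., requires_grad=True) factory call as used in the CrypTen training tutorials, and get_plain_text() for decryption); the toy loss below is purely illustrative.

    import torch

    import crypten

    crypten.init()

    # leaf tensors: their gradients are retained after backward()
    x = crypten.cryptensor(torch.tensor([1.0, 2.0, 3.0]), requires_grad=True)
    w = crypten.cryptensor(torch.tensor([0.5, -1.0, 2.0]), requires_grad=True)

    # the forward pass records children, grad_fn, and ctx on each node
    loss = (x * w).sum()  # scalar output, so grad_input defaults to ones

    # top-level call: top_node=True triggers _identify_required_grads();
    # gradients flow down to the leaves and intermediate gradients are freed
    loss.backward()

    print(x.grad.get_plain_text())  # approximately [0.5, -1.0, 2.0]
    print(w.grad.get_plain_text())  # approximately [1.0, 2.0, 3.0]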