in picotron/pipeline_parallel/pipeline_parallel.py [0:0]
def backward(self, input_tensor, output_tensor, output_tensor_grad):
    """
    Run the backward pass for the layers owned by this pipeline stage.

    Gradients flow from ``output_tensor`` back through this stage's autograd
    graph, seeded by the gradient received from the next stage.

    Args:
        input_tensor: Activation this stage consumed in its forward pass, or
            None when there is no upstream activation whose gradient must be
            returned.
        output_tensor: Tensor this stage produced in its forward pass.
        output_tensor_grad: Gradient w.r.t. ``output_tensor`` received from
            the next stage. When None, a tensor of ones shaped like
            ``output_tensor`` is used as the seed instead.

    Returns:
        The gradient stored on ``input_tensor`` after backward, or None when
        ``input_tensor`` is None.
    """
    has_input = input_tensor is not None
    if has_input:
        # Ask autograd to keep .grad on input_tensor even if it is a non-leaf,
        # so it can be handed back (presumably to the previous stage).
        input_tensor.retain_grad()

    seed_grad = (
        torch.ones_like(output_tensor, memory_format=torch.preserve_format)
        if output_tensor_grad is None
        else output_tensor_grad
    )

    # torch.autograd.backward accumulates into the leaves' .grad; the graph is
    # released afterwards because retain_graph=False.
    torch.autograd.backward(
        output_tensor,
        grad_tensors=seed_grad,
        retain_graph=False,
        create_graph=False,
    )

    if not has_input:
        return None
    return input_tensor.grad