in fairscale/experimental/nn/offload.py [0:0]
def backward(ctx, *grad_outputs):  # type: ignore # pragma: no cover
    if not torch.autograd._is_checkpoint_valid():
        raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
    inputs = ctx.inputs
    model_instance = ctx.model_instance

    for i, need_grad in enumerate(ctx.grad_requirements):
        inputs[i].requires_grad = need_grad

    all_grads = [grad_outputs]
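    # all_grads[-1] always holds the gradients flowing into the shard currently being
    # processed; each iteration of the loop below appends the gradients w.r.t. that
    # shard's inputs, which become the upstream gradients for the next (earlier) shard.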
    for model_shard, activation in zip(
        reversed(model_instance.model_slices), reversed(model_instance._activations[:-1])
    ):
        with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_load"):
            # Move the activation to the GPU.
            activation = tuple([a.cuda() for a in list(activation)])

            # Move the model shard to the GPU.
            model_shard.backward_load()

        # Save the current RNG state before recomputing the forward pass.
        bwd_rng_state = torch.get_rng_state()
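        # Two RNG states are juggled here: ctx.fwd_rng_state (captured during the forward
        # pass) is restored before each recomputation so that stochastic ops such as
        # dropout replay identically, and bwd_rng_state is restored afterwards so the rest
        # of the backward pass is unaffected.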
        # TODO(anj-s): Why detach inputs?
        activation = torch.utils.checkpoint.detach_variable(activation)
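        # detach_variable() returns detached copies (with requires_grad preserved), so the
        # recomputed forward below builds a graph rooted at this slice boundary instead of
        # extending the graph of the earlier shards, as in standard activation checkpointing.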
        # Get the most recent gradients, i.e. the gradients w.r.t. this shard's output.
        final_grads = all_grads[-1]

        if isinstance(activation, torch.Tensor):
            activation = (activation,)
        if isinstance(final_grads, torch.Tensor):
            final_grads = (final_grads,)

        # Iterate through all the inputs/outputs of a shard (there could be multiple).
        chunked_grad_list: List[Any] = []
        # Chunk the activation and the gradient according to the number of microbatches.
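        # Note that unpacking the tuples into torch.chunk() assumes a single tensor at the
        # slice boundary; the second positional argument is then the number of chunks.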
        for chunked_activation, chunked_grad in zip(
            torch.chunk(*activation, model_instance._num_microbatches),  # type: ignore
            torch.chunk(*final_grads, model_instance._num_microbatches),  # type: ignore
        ):
            # Restore the RNG state that was saved before the forward pass so the
            # recomputation matches the original forward.
            torch.set_rng_state(ctx.fwd_rng_state)

            if isinstance(chunked_activation, torch.Tensor):
                chunked_activation = (chunked_activation,)  # type: ignore
            if isinstance(chunked_grad, torch.Tensor):
                chunked_grad = (chunked_grad,)  # type: ignore

            # Since we need the grad of non-leaf tensors, mark them as requiring grad and
            # retain their grad after backward (integer tensors cannot require grad).
            for a in chunked_activation:
                if a.dtype == torch.long:
                    continue
                a.requires_grad = True
                a.retain_grad()
            with torch.autograd.profiler.record_function(
                "fairscale.experimental.nn.offload:forward_pass_with_enable_grad"
            ):
                with torch.enable_grad():
                    # Recompute the output of this shard from the activation stored at the
                    # slice boundary.
                    outputs = model_shard(*chunked_activation)

            # Restore the RNG state that was in effect at the start of this shard's iteration.
            torch.set_rng_state(bwd_rng_state)
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_pass"):
                torch.autograd.backward(outputs, chunked_grad)
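            # The backward call above populates .grad on the detached chunked activations;
            # collect them as the upstream gradients for the next (earlier) shard.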
            intermediate_grads = []
            for a in chunked_activation:
                if a.grad is not None:
                    intermediate_grads.append(a.grad)
            if None not in intermediate_grads:
                chunked_grad_list += intermediate_grads

        if chunked_grad_list:
            # Concatenate the per-microbatch grads and append them to all_grads; these
            # tensors stay on the GPU.
            all_grads.append(torch.cat(chunked_grad_list).squeeze(-1))  # type: ignore
        with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_drop"):
            # Move the shard back to the CPU, which also moves its grad tensors to the CPU.
            # The activations do not need to be moved since we operated on GPU copies of them.
            model_shard.backward_drop()

    detached_inputs = model_instance._activations[0]
    grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
    return (None, None) + grads
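For context, here is a minimal usage sketch (not part of offload.py), assuming the OffloadModel wrapper exported by this module; the layer sizes, devices, and slicing parameters are arbitrary example values. Wrapping a model with checkpoint_activation=True makes loss.backward() run the shard-by-shard checkpointed backward shown above.

import torch
from fairscale.experimental.nn.offload import OffloadModel

# Toy model that will be split into num_slices shards, kept on the CPU, and
# streamed to the GPU one shard at a time.
model = torch.nn.Sequential(
    torch.nn.Linear(32, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 8),
)
offload_model = OffloadModel(
    model=model,
    device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    num_slices=3,
    checkpoint_activation=True,
    num_microbatches=2,
)

inputs = torch.randn(4, 32).cuda()  # the wrapper chunks the batch into num_microbatches micro-batches
loss = offload_model(inputs).sum()
loss.backward()  # triggers the checkpointed backward above, shard by shard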