in fairscale/experimental/nn/offload.py [0:0]
def forward(ctx: Any, inputs: Any, dummy_input: Any, model_instance: Any) -> Any:
    inputs = inputs if isinstance(inputs, tuple) else (inputs,)

    ctx.inputs = inputs
    ctx.model_instance = model_instance
    # TODO(anj-s): We might need to store this for each boundary activation.
    # Currently we assume all boundary activation inputs require gradients.
    ctx.grad_requirements = tuple(x.requires_grad for x in inputs)
    ctx.fwd_rng_state = torch.get_rng_state()
    # List of input activations, starting with the given input.
    model_instance._activations = [inputs]
    # Enumerate through the layer shards and apply the activations from the previous shard.
    for index, layer_shard in enumerate(model_instance.model_slices):
        with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:forward_load"):
            # Move the current activations onto the device.
            model_instance._activations[index] = tuple(a.cuda() for a in model_instance._activations[index])
            # Move the current layer shard onto the device.
            layer_shard.forward_load()
        # Apply the forward pass and store the activations on the CPU.
        inputs = model_instance._activations[index]
        with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:no_grad_forward_pass"):
            with torch.no_grad():
                output_list: List[Any] = []
                for given_input in inputs:
                    # Split the input into microbatches, run each one through the
                    # shard, then concatenate the results back into a single tensor.
                    given_input_list = torch.chunk(given_input, model_instance._num_microbatches)
                    given_output_list = []
                    for given_input_chunk in given_input_list:
                        output = layer_shard(given_input_chunk)
                        given_output_list.append(output)
                    given_output = torch.cat(given_output_list).squeeze(-1)
                    output_list.append(given_output)
                output = tuple(output_list)

        output = output if isinstance(output, tuple) else (output,)
        with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:forward_drop"):
            # Move the activations used by the current shard back to the CPU.
            model_instance._activations[index] = tuple(a.cpu() for a in model_instance._activations[index])
            # The newly computed activations remain on the GPU, ready for the next shard's computation.
            model_instance._activations.append(output)
            # Move the layer shard back to the CPU.
            layer_shard.forward_drop()
    # The last activation would lose its gradient function if we moved it to the CPU,
    # since all grad functions live on the device that ran the forward pass. It
    # therefore remains on the GPU and is the return value of this function. Note
    # that this assumes the target is also on the GPU, which is required for
    # calculating the loss.
    result = model_instance._activations[-1]
    result = [r.cuda() for r in result]
    for r in result:
        r.requires_grad = True
    return result[0] if len(result) == 1 else result
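
The inner loops above implement microbatching: each boundary activation is split with torch.chunk, pushed through the shard one microbatch at a time, and concatenated back together. A standalone sketch of that round trip, using an arbitrary (8, 16) tensor and a ReLU as a stand-in for a layer shard (both are illustrative assumptions, not part of this file):

import torch

x = torch.randn(8, 16)                                # stand-in boundary activation
chunks = torch.chunk(x, 2)                            # 2 microbatches of shape (4, 16)
outs = [torch.nn.functional.relu(c) for c in chunks]  # "layer_shard" applied per microbatch
y = torch.cat(outs).squeeze(-1)                       # reassembled; squeeze(-1) is a no-op here
assert y.shape == x.shape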
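
For context, this forward pass is reached by wrapping an nn.Sequential model in fairscale's OffloadModel with checkpoint_activation=True, following the pattern in the fairscale offload tutorial. The toy model, batch size, loss, and optimizer below are illustrative assumptions, not part of this file; a CUDA device is assumed to be available.

import torch
from fairscale.experimental.nn.offload import OffloadModel

# Toy sequential model; layer sizes are arbitrary placeholders.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
)

offload_model = OffloadModel(
    model=model,
    device=torch.device("cuda"),         # device that runs each shard's forward pass
    offload_device=torch.device("cpu"),  # where idle shards and boundary activations live
    num_slices=3,                        # number of layer shards (model_slices above)
    checkpoint_activation=True,          # routes the forward through this function
    num_microbatches=1,                  # _num_microbatches used by torch.chunk above
)

optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

# Inputs and targets both live on the GPU, matching the note above that the loss
# is computed against the on-device activation returned by this function.
data = torch.randn(32, 128, device="cuda")
target = torch.randint(0, 10, (32,), device="cuda")

optimizer.zero_grad()
loss = criterion(offload_model(data), target)
loss.backward()
optimizer.step()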