def forward()

in fairscale/experimental/nn/offload.py [0:0]


    def forward(ctx: Any, inputs: Any, dummy_input: Any, model_instance: Any) -> Any:
        """Run the forward pass one model slice at a time.

        For every entry of ``model_instance.model_slices`` the current activations
        are moved with ``.cuda()``, the slice is applied under ``torch.no_grad()``
        to each input (split via ``torch.chunk(..., model_instance._num_microbatches)``),
        and the consumed activations are moved back with ``.cpu()``.

        Args:
            ctx: Context object. Receives the attributes ``inputs``, ``model_instance``,
                ``grad_requirements`` and ``fwd_rng_state``.
            inputs: A single input or a tuple of inputs. A lone input is wrapped
                in a 1-tuple.
            dummy_input: Not read by this method.
            model_instance: Object exposing ``model_slices`` and ``_num_microbatches``.
                Its ``_activations`` list is rebuilt here: the inputs first, then
                one tuple of outputs per slice.

        Returns:
            The single final activation when there is exactly one, otherwise a list
            of the final activations. Each has been passed through ``.cuda()`` and
            had ``requires_grad`` set to ``True``.
        """
        if not isinstance(inputs, tuple):
            inputs = (inputs,)

        # Stash what is recorded on the context object for later use.
        ctx.inputs = inputs
        ctx.model_instance = model_instance
        # TODO(anj-s): We might need to store this for each boundary activation.
        # For now only the `requires_grad` flags of the inputs to this call are recorded.
        ctx.grad_requirements = tuple(x.requires_grad for x in inputs)
        ctx.fwd_rng_state = torch.get_rng_state()

        # Activation history: entry 0 is the given input, entry i + 1 is the output of slice i.
        model_instance._activations = [inputs]

        for shard_idx, shard in enumerate(model_instance.model_slices):
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:forward_load"):
                # Move the activations this slice consumes onto the device ...
                model_instance._activations[shard_idx] = tuple(
                    act.cuda() for act in model_instance._activations[shard_idx]
                )
                # ... and ask the slice to load itself.
                shard.forward_load()

            shard_inputs = model_instance._activations[shard_idx]
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:no_grad_forward_pass"):
                with torch.no_grad():
                    per_input_outputs: List[Any] = []
                    for tensor in shard_inputs:
                        # Feed the slice one microbatch at a time, then stitch the pieces back together.
                        microbatches = torch.chunk(tensor, model_instance._num_microbatches)
                        merged = torch.cat([shard(microbatch) for microbatch in microbatches]).squeeze(-1)
                        per_input_outputs.append(merged)
                    shard_outputs = tuple(per_input_outputs)

            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:forward_drop"):
                # The activations just consumed go back to the CPU.
                model_instance._activations[shard_idx] = tuple(
                    act.cpu() for act in model_instance._activations[shard_idx]
                )
                # The fresh outputs are appended as-is (no `.cpu()` call), ready for the next slice.
                model_instance._activations.append(shard_outputs)
                # Let the slice release itself.
                shard.forward_drop()

        # The final activations are deliberately not sent to the CPU: they are the return value
        # of this function, and moving them off the device would lose the gradient function.
        # This assumes the target used to compute the loss also lives on the GPU.
        final = [act.cuda() for act in model_instance._activations[-1]]
        for act in final:
            act.requires_grad = True
        if len(final) == 1:
            return final[0]
        return final