def forward()

in src/nanotron/parallel/pipeline_parallel/block.py


    def forward(self, **kwargs):
        """Forward pass

        We use a mechanism using TensorPointers to pass Tensors around
        All non Tensor object or TensorPointers are considered pass-through, they are never meant to be communicated cross process

        :param kwargs: Dict[str, Union[TensorPointer, torch.Tensor, Any]]
        :return: Dict[str, Union[TensorPointer, torch.Tensor, Any]
        """
        assert self.module_input_keys == set(
            kwargs.keys()
        ), f"Expected {self.module_input_keys}, got {set(kwargs.keys())}"

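        # Sort the kwargs so that the point-to-point sends and recvs below are issued in a
        # consistent order on the sending and receiving ranks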
        sorted_kwargs = sorted(kwargs.items(), key=get_sort_key(dist.get_rank(self.p2p.pg)))

        # If the current rank is not the one running the compute
        if dist.get_rank(self.p2p.pg) != self.rank:
            # TODO(kunhao): A better design is to pop this up for both if else branches.
            batch_send_recv = BatchTensorSendRecvState(self.p2p)
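            # Queue all point-to-point sends here and issue them together with flush() below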
            # Send locally held activations to the rank that runs the compute
            for name, tensor in sorted_kwargs:
                if isinstance(tensor, TensorPointer):
                    # Current rank is neither the rank holding the data nor the rank responsible for computing the block
                    continue
                else:
                    assert isinstance(tensor, torch.Tensor)
                    # We need to send the tensor to the rank that actually runs the compute
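                    # If a pipeline engine is attached, hand the send over to its buffer so the
                    # engine decides when it actually happens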
                    if self.pipeline_state is not None:
                        send_to_pipeline_state_buffer(
                            tensor,
                            to_rank=self.rank,
                            p2p=self.p2p,
                            pipeline_state=self.pipeline_state,
                        )
                        continue

                    if tensor.requires_grad is True:
                        raise ValueError(
                            f"Pipeline engine is None and tensor requires grad. Tried sending a tensor to {self.rank}. Usually that means that your model is pipeline sharded and you haven't chosen a specific pipeline engine."
                        )

                    batch_send_recv.add_send(tensor=tensor, to_rank=self.rank)

            batch_send_recv.flush()
            # Return TensorPointers signalling that all outputs live on the rank responsible for computing the block
            # TODO @thomasw21: Figure out a way to build dummy_input in a generic sense, and remove the necessity to have Dict[str, torch.Tensor] as output
            return {k: TensorPointer(group_rank=self.rank) for k in self.module_output_keys}

        # Receive activations held on other ranks onto the local (compute) rank
        new_kwargs: Dict[str, torch.Tensor] = {}
        name_to_recv_id = {}
        batch_send_recv = BatchTensorSendRecvState(self.p2p)
        for name, tensor in sorted_kwargs:
            if isinstance(tensor, TensorPointer):
                # Current rank is the one running the compute, so we need to fetch the tensor from the rank that holds it
                # new_kwargs[name] = recv_tensor(from_rank=tensor.group_rank, p2p=self.p2p)
                # This assumes that prior communication was already done

                # In case of interleaved 1f1b, if this is the second model chunk, then we need to send the previous activations before receiving the current activations
                if isinstance(self.pipeline_state, PipelineTrainBatchState):
                    for _ in range(len(self.pipeline_state.microbatches_activations_to_send)):
                        send_activation = self.pipeline_state.microbatches_activations_to_send.popleft()
                        # Execute
                        send_activation()

                if self.pipeline_state is not None:
                    new_kwargs[name] = recv_from_pipeline_state_buffer(
                        from_rank=tensor.group_rank,
                        p2p=self.p2p,
                        pipeline_state=self.pipeline_state,
                    )
                    continue

                # We don't store the result in a buffer; queue the recv and pick it up after flush()
                recv_id = batch_send_recv.add_recv(from_rank=tensor.group_rank)
                name_to_recv_id[name] = recv_id
            else:
                new_kwargs[name] = tensor

        # Run receiving communications
        recv_tensors = batch_send_recv.flush()
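        # flush() returns the received tensors, indexed by the recv_id handed out by add_recv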
        assert len(recv_tensors) == len(name_to_recv_id)
        for name, recv_id in name_to_recv_id.items():
            assert name not in new_kwargs
            new_tensor = recv_tensors[recv_id]
            if new_tensor.requires_grad is True:
                raise ValueError(
                    f"Pipeline engine is None and tensor requires grad. Tried receiving a tensor to {self.rank}. Usually that means that your model is pipeline sharded and you haven't chosen a specific pipeline engine."
                )
            new_kwargs[name] = new_tensor

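        # Every input is now a local torch.Tensor; run the wrapped module on this rank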
        output = self.pp_block(**new_kwargs)

        # Convenience for modules that return a single tensor instead of a dict
        if isinstance(output, torch.Tensor):
            assert len(self.module_output_keys) == 1
            output = {next(iter(self.module_output_keys)): output}

        assert isinstance(output, dict), "Modules within a Pipeline Block have to return a Dict[str, torch.Tensor]"
        assert self.module_output_keys == set(
            output.keys()
        ), f"Expected {self.module_output_keys}, got {set(output.keys())}"

        return output
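
For context, the sketch below illustrates the calling convention this forward pass implies. It is a hypothetical, minimal example: the call_block helper and the "hidden_states" key are invented for illustration (they assume a PipelineBlock whose module_input_keys and module_output_keys are both {"hidden_states"}), and the TensorPointer import path is assumed from the repo layout; only the TensorPointer pass-through convention itself comes from the code above.

    import torch
    from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer

    def call_block(block, current_rank: int, data_rank: int, hidden_states: torch.Tensor):
        # Hypothetical helper: every rank calls the block with the same keyword names,
        # but only the rank that owns the activation passes the real tensor; all other
        # ranks pass a TensorPointer naming the owner.
        if current_rank == data_rank:
            inputs = {"hidden_states": hidden_states}
        else:
            inputs = {"hidden_states": TensorPointer(group_rank=data_rank)}

        outputs = block(**inputs)

        # Only block.rank gets a real tensor back; every other rank receives
        # TensorPointer(group_rank=block.rank) and simply hands it to the next block.
        return outputs["hidden_states"]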