src/nanotron/parallel/pipeline_parallel/engine.py [139:164]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    ) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]:
        # Assign a new state for the current batch
        state = PipelineTrainBatchState()  # TODO: do i need state?
        self.nb_microbatches = nb_microbatches

        outputs = []

        with attach_pipeline_state_to_model(model=model, pipeline_state=state):
            # All forward
            for micro_batch in batch:
                context = self._get_fwd_context(model=model)
                output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model)
                # TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
                for _ in range(len(state.microbatches_activations_to_send)):
                    send_activation = state.microbatches_activations_to_send.popleft()
                    # Execute
                    send_activation()

                # We make `output` a dict
                if not isinstance(output, dict):
                    output = {"loss": output}

                # Store the loss for each microbatch
                if not isinstance(output["loss"], TensorPointer):
                    output = {k: v.detach() for k, v in output.items()}
                outputs.append(output)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



src/nanotron/parallel/pipeline_parallel/engine.py [189:214]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    ) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]:
        # Assign a new state for the current batch
        state = PipelineTrainBatchState()
        self.nb_microbatches = nb_microbatches

        outputs = []

        with attach_pipeline_state_to_model(model=model, pipeline_state=state):
            # All forward
            for micro_batch in batch:
                context = self._get_fwd_context(model=model)
                output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model)
                # TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
                for _ in range(len(state.microbatches_activations_to_send)):
                    send_activation = state.microbatches_activations_to_send.popleft()
                    # Execute
                    send_activation()

                # We make `output` a dict
                if not isinstance(output, dict):
                    output = {"loss": output}

                # Store the loss for each microbatch
                if not isinstance(output["loss"], TensorPointer):
                    output = {k: v.detach() for k, v in output.items()}
                outputs.append(output)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



