dualpipe/dualpipe.py [276:325]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            return

        chunk_id = self.current_send_b_chunk_id[phase]
        self.current_send_b_chunk_id[phase] += 1
        tensors = self.input_grad_chunks[phase][chunk_id]
        self.input_grad_chunks[phase][chunk_id] = None

        comm.append_isend(self.comm_ops, tensors, self.prev_rank if phase == 0 else self.next_rank, self.group)

    def _commit_and_wait_comm(self) -> None:
        """
        Submit all queued communication ops as one batch and block until done.

        Does nothing when ``self.comm_ops`` is empty. Otherwise the queue is
        passed to ``dist.batch_isend_irecv``, every returned request is waited
        on in order, ``self.comm_ops`` is reset to a new empty list, and
        ``self._free_tensors()`` is invoked.
        """
        pending_ops = self.comm_ops
        if pending_ops:
            for handle in dist.batch_isend_irecv(pending_ops):
                handle.wait()
            # Rebind to a fresh list; the list that was just submitted is not
            # mutated in place.
            self.comm_ops = []
            self._free_tensors()

    def step(
        self,
        *inputs: Optional[torch.Tensor],
        num_chunks: int = 0,
        criterion: Optional[Callable] = None,
        labels: List[Optional[torch.Tensor]] = [],
        return_outputs: bool = False,
    ) -> Tuple[Optional[torch.Tensor], Optional[Union[torch.Tensor, Tuple[torch.Tensor]]]]:
        """
        Execute a training or inference step.

        Arguments:
            *inputs: Module inputs. Required only on the first/last ranks.
            num_chunks: The number of micro-batches.
            criterion: Loss function, invoked as ``criterion(*outputs, *labels)``. Required only on the first/last ranks.
            labels: Labels of the loss function. Required only on the first/last ranks.
                labels on the first rank corresponds to inputs on the last rank.
                labels on the last rank corresponds to inputs on the first rank.
            return_outputs: Whether to return outputs on the first/last ranks. Default: ``False``.

        Returns: (loss, outputs)
            loss: Loss for the batch.
                loss on the first rank corresponds to inputs on the last rank.
                loss on the last rank corresponds to inputs on the first rank.
                Otherwise: ``None``.
            outputs: Returned only if ``return_outputs=True``.
                outputs on the first rank corresponds to inputs on the last rank.
                outputs on the last rank corresponds to inputs on the first rank.
                Otherwise: ``None``.

        """
        assert comm.TENSOR_SHAPES is not None and comm.TENSOR_DTYPE is not None, \
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



dualpipe/dualpipev.py [270:311]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            return

        chunk_id = self.current_send_b_chunk_id[phase]
        self.current_send_b_chunk_id[phase] += 1
        tensors = self.input_grad_chunks[phase][chunk_id]
        self.input_grad_chunks[phase][chunk_id] = None

        comm.append_isend(self.comm_ops, tensors, self.prev_rank if phase == 0 else self.next_rank, self.group)

    def _commit_and_wait_comm(self) -> None:
        """
        Submit all queued communication ops as one batch and block until done.

        Does nothing when ``self.comm_ops`` is empty. Otherwise the queue is
        passed to ``dist.batch_isend_irecv``, every returned request is waited
        on in order, ``self.comm_ops`` is reset to a new empty list, and
        ``self._free_tensors()`` is invoked.
        """
        pending_ops = self.comm_ops
        if pending_ops:
            for handle in dist.batch_isend_irecv(pending_ops):
                handle.wait()
            # Rebind to a fresh list; the list that was just submitted is not
            # mutated in place.
            self.comm_ops = []
            self._free_tensors()

    def step(
        self,
        *inputs: Optional[torch.Tensor],
        num_chunks: int = 0,
        criterion: Optional[Callable] = None,
        labels: List[Optional[torch.Tensor]] = [],
        return_outputs: bool = False,
    ) -> Tuple[Optional[torch.Tensor], Optional[Union[torch.Tensor, Tuple[torch.Tensor]]]]:
        """
        Execute a training or inference step.

        Arguments:
            *inputs: Module inputs. Required only on the first rank.
            num_chunks: The number of micro-batches.
            criterion: Loss function, invoked as ``criterion(*outputs, *labels)``. Required only on the first rank.
            labels: Labels of the loss function. Required only on the first rank.
            return_outputs: Whether to return outputs on the first rank. Default: ``False``.

        Returns: (loss, outputs)
            loss: Loss for the batch. Returned only on the first rank.
            outputs: Module outputs. Returned only if ``return_outputs=True`` and on the first rank.

        """
        assert comm.TENSOR_SHAPES is not None and comm.TENSOR_DTYPE is not None, \
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



