in dualpipe/dualpipe.py [0:0]
def _forward_compute_chunk(self, phase: int) -> None:
    """Run the forward computation for the next pending chunk of a phase.

    The effective phase is ``phase`` XOR ``self.is_in_second_half``. All
    per-phase containers (``current_f_chunk_id``, ``input_chunks``,
    ``module``, ``labels``, ``output_chunks``) are indexed by that
    effective phase.

    Args:
        phase: Requested phase index. The last-stage check below only
            recognises the effective values 0 and 1.

    Side effects:
        * ``self.current_f_chunk_id[phase]`` is advanced by one.
        * When ``self.forward_only`` is truthy, the consumed slot in
          ``self.input_chunks[phase]`` is overwritten with ``None``.
        * On the last stage with a criterion set, the loss is appended to
          ``self.loss_chunks``.
        * The outputs are appended to ``self.output_chunks[phase]`` unless
          this is the last stage and ``self.return_outputs`` is falsy.
    """
    # Flip the requested phase when the scheduler is in its second half.
    active = phase ^ self.is_in_second_half

    # Claim the next chunk index for this phase before reading its inputs.
    counters = self.current_f_chunk_id
    idx = counters[active]
    counters[active] += 1

    pending = self.input_chunks[active]
    args = pending[idx]
    if self.forward_only:
        # NOTE(review): presumably drops the reference because no backward
        # pass will need these inputs -- confirm against the caller.
        pending[idx] = None

    # The last stage is phase 1 on the first rank, or phase 0 on the last rank.
    final_stage = (active == 1 and self.is_first_rank) or (
        active == 0 and self.is_last_rank
    )

    result = self.module[active](*args)
    if isinstance(result, torch.Tensor):
        # Normalise a bare tensor to a one-element list so it can be unpacked.
        result = [result]

    if final_stage and self.criterion is not None:
        targets = self.labels[active][idx]
        chunk_loss = self.criterion(*result, *targets)
        self.loss_chunks.append(chunk_loss)

    # Last-stage outputs are kept only when the caller asked for them.
    if final_stage and not self.return_outputs:
        return
    self.output_chunks[active].append(result)