in dualpipe/dualpipe.py [0:0]
def _forward_backward_compute_chunk(self, phase0: int, phase1: int) -> None:
if self.forward_only:
self._forward_compute_chunk(phase0)
return
if not self.overlapped_forward_backward:
self._forward_compute_chunk(phase0)
self._backward_compute_chunk(phase1)
return
# pre-forward
phase0 ^= self.is_in_second_half
chunk_id0 = self.current_f_chunk_id[phase0]
self.current_f_chunk_id[phase0] += 1
module0 = self.module[phase0]
inputs0 = self.input_chunks[phase0][chunk_id0]
is_last_stage0 = (self.is_first_rank and phase0 == 1) or (self.is_last_rank and phase0 == 0)
if is_last_stage0 and self.criterion is not None:
labels0 = self.labels[phase0][chunk_id0]
criterion0 = self.criterion
else:
labels0 = []
criterion0 = None
# pre-backward
phase1 ^= self.is_in_second_half
chunk_id1 = self.current_b_chunk_id[phase1]
self.current_b_chunk_id[phase1] += 1
module1 = self.module[phase1]
is_last_stage1 = (self.is_first_rank and phase1 == 1) or (self.is_last_rank and phase1 == 0)
if is_last_stage1:
loss1 = self.loss_chunks[chunk_id1]
outputs1 = []
output_grads1 = []
else:
loss1 = None
outputs1 = self.output_chunks[phase1][chunk_id1]
if not self.return_outputs:
self.output_chunks[phase1][chunk_id1] = None
output_grads1 = self.output_grad_chunks[phase1][chunk_id1]
self.output_grad_chunks[phase1][chunk_id1] = None
non_empty = [(t, g) for t, g in zip(outputs1, output_grads1) if g is not None]
outputs1, output_grads1 = list(zip(*non_empty))
# forward & backward
outputs0, loss0 = type(module0).overlapped_forward_backward(
module0, inputs0, criterion0, labels0,
module1, loss1, outputs1, output_grads1,
)
# post-forward
if (not is_last_stage0) or self.return_outputs:
self.output_chunks[phase0].append(outputs0)
if is_last_stage0 and self.criterion is not None:
self.loss_chunks.append(loss0)
# post-backward
inputs = self.input_chunks[phase1][chunk_id1]
self.input_chunks[phase1][chunk_id1] = None
input_grads1 = [t.grad for t in inputs]
self.input_grad_chunks[phase1].append(input_grads1)