in dualpipe/dualpipe.py [0:0]
def _backward_compute_chunk(self, phase: int, enable_zb: bool = False) -> None:
    """Run the backward computation for the next pending chunk of ``phase``.

    Does nothing when ``self.forward_only`` is set. Otherwise the next
    backward chunk index for the (adjusted) phase is consumed, backward is
    run either from the stored loss or from the stored outputs and their
    received gradients, and the gradients of the chunk's inputs are
    appended to ``self.input_grad_chunks[phase]``.

    Args:
        phase: Phase index. It is XOR-ed with ``self.is_in_second_half``
            before being used to index the per-phase chunk lists.
        enable_zb: Value assigned to ``WeightGradStore.enabled`` for the
            duration of the backward call. When true,
            ``WeightGradStore.flush()`` is called afterwards.
            NOTE(review): presumably "zb" means zero-bubble scheduling
            with deferred weight gradients -- confirm against
            ``WeightGradStore``.
    """
    if self.forward_only:
        return

    # Flip the phase bit when this rank is flagged as being in the second half.
    phase ^= self.is_in_second_half
    chunk_id = self.current_b_chunk_id[phase]
    self.current_b_chunk_id[phase] += 1

    is_last_stage = (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0)

    WeightGradStore.enabled = enable_zb
    try:
        if is_last_stage:
            # The last stage starts backward from its own loss.
            loss = self.loss_chunks[chunk_id]
            loss.backward()
            loss.detach_()
        else:
            outputs = self.output_chunks[phase][chunk_id]
            if not self.return_outputs:
                # Drop the stored reference once outputs are not needed by the caller.
                self.output_chunks[phase][chunk_id] = None
            output_grads = self.output_grad_chunks[phase][chunk_id]
            self.output_grad_chunks[phase][chunk_id] = None
            # Only outputs that actually received a gradient take part in backward.
            non_empty = [(t, g) for t, g in zip(outputs, output_grads) if g is not None]
            # Check before unpacking: ``zip(*[])`` yields nothing, so a
            # two-target unpack of it would raise ValueError.
            if non_empty:
                outputs, output_grads = zip(*non_empty)
                run_backward(outputs, output_grads)
    finally:
        # Always restore the class-level flag, even if backward raised, so
        # it does not leak into unrelated backward passes.
        WeightGradStore.enabled = False
    if enable_zb:
        WeightGradStore.flush()

    inputs = self.input_chunks[phase][chunk_id]
    self.input_chunks[phase][chunk_id] = None
    input_grads = [t.grad for t in inputs]
    self.input_grad_chunks[phase].append(input_grads)