dualpipe/dualpipev.py (306 lines of code) (raw):

from typing import Tuple, List, Union, Callable, Optional

import torch
import torch.nn as nn
import torch.distributed as dist

import dualpipe.comm as comm
from dualpipe.utils import WeightGradStore, run_backward, scatter, gather


class DualPipeV(nn.Module):
    """Pipeline-parallel scheduler that runs two module stages ("phases") per rank.

    Phase 0 receives activations from ``prev_rank`` and sends them to
    ``next_rank``; phase 1 runs in the opposite direction. On the last rank the
    outputs of phase 0 are fed (detached) straight into phase 1, and on the
    first rank phase 1 is the final stage, where the loss is computed.

    Arguments:
        modules: The two stages held by this rank, ``(phase 0, phase 1)``.
        batch_dim: Dimension along which inputs/labels are split into chunks
            and outputs are gathered.
        process_group: Group used for P2P communication. Defaults to the
            default process group.
        rank_mapping: Maps the rank in ``process_group`` to the pipeline rank.
            Defaults to the identity mapping.
    """

    def __init__(
        self,
        modules: Tuple[nn.Module, nn.Module],
        batch_dim: int = 0,
        process_group: Optional[dist.ProcessGroup] = None,
        rank_mapping: Optional[List[int]] = None,
    ) -> None:
        super().__init__()

        assert next(modules[0].parameters()).device == torch.device(torch.cuda.current_device())
        self.module = nn.ModuleList(modules)
        # The fused forward/backward path is only used when both stages are of the
        # same class and that class provides `overlapped_forward_backward`.
        self.overlapped_forward_backward = (
            type(modules[0]) is type(modules[1])
            and hasattr(type(modules[0]), "overlapped_forward_backward")
        )
        self.batch_dim = batch_dim
        self.group = process_group or dist.distributed_c10d._get_default_group()
        self.num_ranks = self.group.size()

        # rank_mapping: Map rank in process_group to actual pp rank.
        # rank_inverse_mapping: Map actual pp rank to rank in process_group.
        if rank_mapping is None:
            rank_mapping = list(range(self.num_ranks))
        # One extra slot is left as None so that index -1 (previous of pp rank 0)
        # and index num_ranks (next of the last pp rank) both resolve to None.
        rank_inverse_mapping = [None] * (self.num_ranks + 1)
        for i in range(self.num_ranks):
            rank_inverse_mapping[rank_mapping[i]] = i

        self.rank = rank_mapping[self.group.rank()]
        self.prev_rank = rank_inverse_mapping[self.rank - 1]
        self.next_rank = rank_inverse_mapping[self.rank + 1]

        self.is_first_rank = self.rank == 0
        self.is_last_rank = self.rank == self.num_ranks - 1

    def _reset_states(self) -> None:
        """Clear the weight-grad store and all per-step buffers and counters."""
        WeightGradStore.clear()

        # Each tuple is indexed by phase (0 or 1), then by chunk id.
        self.input_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
        self.output_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
        self.input_grad_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
        self.output_grad_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
        self.labels: Optional[List[List[torch.Tensor]]] = None
        self.loss_chunks: List[torch.Tensor] = []
        self.criterion: Optional[Callable] = None

        # Per-phase counters of the next chunk to compute / send / receive.
        self.current_f_chunk_id: List[int] = [0, 0]
        self.current_b_chunk_id: List[int] = [0, 0]
        self.current_send_f_chunk_id: List[int] = [0, 0]
        self.current_send_b_chunk_id: List[int] = [0, 0]
        self.current_recv_f_chunk_id: List[int] = [0, 0]
        self.current_recv_b_chunk_id: List[int] = [0, 0]
        self.comm_ops: List[dist.P2POp] = []
        self.to_free: List[torch.Tensor] = []

    @staticmethod
    def _drop_missing_grads(outputs, output_grads) -> Tuple[tuple, tuple]:
        """Return ``(outputs, grads)`` tuples restricted to pairs whose grad is not None.

        Both results are empty tuples when no gradient is present; unpacking
        ``zip(*pairs)`` directly would raise ``ValueError`` in that case.
        """
        pairs = [(t, g) for t, g in zip(outputs, output_grads) if g is not None]
        if not pairs:
            return (), ()
        kept_outputs, kept_grads = zip(*pairs)
        return kept_outputs, kept_grads

    def _forward_compute_chunk(self, phase: int) -> None:
        """Run the forward pass of the next pending chunk of ``phase``.

        Appends the loss on the final stage (first rank, phase 1) when a
        criterion is set, and stores the outputs for sending/backward/return.
        """
        chunk_id = self.current_f_chunk_id[phase]
        self.current_f_chunk_id[phase] += 1
        inputs = self.input_chunks[phase][chunk_id]
        if self.forward_only:
            # No backward will read the inputs, so drop the reference now.
            self.input_chunks[phase][chunk_id] = None

        is_last_stage = (self.is_first_rank and phase == 1)

        outputs = self.module[phase](*inputs)
        outputs = [outputs] if isinstance(outputs, torch.Tensor) else outputs
        if is_last_stage and self.criterion is not None:
            labels = self.labels[chunk_id]
            loss = self.criterion(*outputs, *labels)
            self.loss_chunks.append(loss)

        if self.is_last_rank and phase == 0:
            # The last rank hands phase 0 outputs directly to phase 1 as fresh leaf tensors.
            self.input_chunks[1].append([output.detach().requires_grad_() for output in outputs])
        if (not is_last_stage) or self.return_outputs:
            self.output_chunks[phase].append(outputs)

    def _backward_compute_chunk(self, phase: int, enable_zb: bool = False) -> None:
        """Run the backward pass of the next pending chunk of ``phase``.

        ``enable_zb`` toggles ``WeightGradStore.enabled`` around the backward
        call and flushes the store afterwards. No-op in forward-only mode.
        """
        if self.forward_only:
            return

        chunk_id = self.current_b_chunk_id[phase]
        self.current_b_chunk_id[phase] += 1

        is_last_stage = (self.is_first_rank and phase == 1)

        WeightGradStore.enabled = enable_zb
        if is_last_stage:
            loss = self.loss_chunks[chunk_id]
            loss.backward()
            loss.detach_()
        else:
            outputs = self.output_chunks[phase][chunk_id]
            if not self.return_outputs:
                self.output_chunks[phase][chunk_id] = None
            output_grads = self.output_grad_chunks[phase][chunk_id]
            self.output_grad_chunks[phase][chunk_id] = None
            outputs, output_grads = self._drop_missing_grads(outputs, output_grads)
            if len(outputs) > 0:
                run_backward(outputs, output_grads)
        WeightGradStore.enabled = False
        if enable_zb:
            WeightGradStore.flush()

        inputs = self.input_chunks[phase][chunk_id]
        self.input_chunks[phase][chunk_id] = None
        input_grads = [t.grad for t in inputs]
        if self.is_last_rank and phase == 1:
            # Mirror of the forward hand-over: phase 1 input grads are phase 0 output grads.
            self.output_grad_chunks[0].append(input_grads)
        else:
            self.input_grad_chunks[phase].append(input_grads)

    def _forward_backward_compute_chunk(self, phase0: int, phase1: int) -> None:
        """Run a forward chunk of ``phase0`` together with a backward chunk of ``phase1``.

        Falls back to forward only (forward-only mode) or to sequential
        forward-then-backward when the module class offers no
        ``overlapped_forward_backward``.
        """
        if self.forward_only:
            self._forward_compute_chunk(phase0)
            return

        if not self.overlapped_forward_backward:
            self._forward_compute_chunk(phase0)
            self._backward_compute_chunk(phase1)
            return

        # pre-forward
        chunk_id0 = self.current_f_chunk_id[phase0]
        self.current_f_chunk_id[phase0] += 1
        module0 = self.module[phase0]
        inputs0 = self.input_chunks[phase0][chunk_id0]
        is_last_stage0 = (self.is_first_rank and phase0 == 1)

        if is_last_stage0 and self.criterion is not None:
            labels0 = self.labels[chunk_id0]
            criterion0 = self.criterion
        else:
            labels0 = []
            criterion0 = None

        # pre-backward
        chunk_id1 = self.current_b_chunk_id[phase1]
        self.current_b_chunk_id[phase1] += 1
        module1 = self.module[phase1]
        is_last_stage1 = (self.is_first_rank and phase1 == 1)

        if is_last_stage1:
            loss1 = self.loss_chunks[chunk_id1]
            outputs1 = []
            output_grads1 = []
        else:
            loss1 = None
            outputs1 = self.output_chunks[phase1][chunk_id1]
            if not self.return_outputs:
                self.output_chunks[phase1][chunk_id1] = None
            output_grads1 = self.output_grad_chunks[phase1][chunk_id1]
            self.output_grad_chunks[phase1][chunk_id1] = None
            outputs1, output_grads1 = self._drop_missing_grads(outputs1, output_grads1)

        # forward & backward
        outputs0, loss0 = type(module0).overlapped_forward_backward(
            module0, inputs0, criterion0, labels0,
            module1, loss1, outputs1, output_grads1,
        )

        # post-forward
        if self.is_last_rank and phase0 == 0:
            self.input_chunks[1].append([output.detach().requires_grad_() for output in outputs0])
        if (not is_last_stage0) or self.return_outputs:
            self.output_chunks[phase0].append(outputs0)
        if is_last_stage0 and self.criterion is not None:
            self.loss_chunks.append(loss0)

        # post-backward
        inputs = self.input_chunks[phase1][chunk_id1]
        self.input_chunks[phase1][chunk_id1] = None
        input_grads1 = [t.grad for t in inputs]
        if self.is_last_rank and phase1 == 1:
            self.output_grad_chunks[0].append(input_grads1)
        else:
            self.input_grad_chunks[phase1].append(input_grads1)

    def _forward_chunk(self, phase: int, recv: bool = True, send: bool = True) -> None:
        """Optionally queue a receive, wait for comm, run forward, optionally queue a send."""
        if recv:
            self._recv_forward(phase)
        self._commit_and_wait_comm()

        self._forward_compute_chunk(phase)

        if send:
            self._send_forward(phase)

    def _backward_chunk(self, phase: int, enable_zb: bool = False, recv: bool = True, send: bool = True) -> None:
        """Optionally queue a grad receive, wait for comm, run backward, optionally queue a grad send."""
        if recv:
            self._recv_backward(phase)
        self._commit_and_wait_comm()

        self._backward_compute_chunk(phase, enable_zb)

        if send:
            self._send_backward(phase)

    def _forward_backward_chunk(self, phase0: int, phase1: int, recv0: bool = True) -> None:
        """Communicate and compute one forward chunk (``phase0``) plus one backward chunk (``phase1``)."""
        if recv0:
            self._recv_forward(phase0)
        self._recv_backward(phase1)
        self._commit_and_wait_comm()

        self._forward_backward_compute_chunk(phase0, phase1)

        self._send_forward(phase0)
        self._send_backward(phase1)

    def _weight_chunk(self) -> None:
        """Wait for pending comm, then pop one entry from ``WeightGradStore``.

        No-op in forward-only mode.
        """
        if self.forward_only:
            return

        self._commit_and_wait_comm()

        # Assume FIFO
        WeightGradStore.pop()

    def _free_tensors(self) -> None:
        """Release the storage of every tensor queued in ``to_free``.

        Each tensor's ``data`` is replaced by an empty tensor; view tensors are
        rejected by assertion.
        """
        for tensor in self.to_free:
            assert tensor._base is None, f"pipeline stage should not return view tensors {dist.get_rank(), tensor.shape}"
            tensor.data = torch.Tensor()
        self.to_free = []

    def _recv_forward(self, phase: int) -> None:
        """Queue a receive of the next input chunk of ``phase``.

        Skipped where inputs do not come over P2P: phase 0 on the first rank
        and phase 1 on the last rank.
        """
        if (self.is_first_rank and phase == 0) or (self.is_last_rank and phase == 1):
            return

        self.current_recv_f_chunk_id[phase] += 1
        tensors = comm.append_irecv(self.comm_ops, self.prev_rank if phase == 0 else self.next_rank, self.group)
        self.input_chunks[phase].append(tensors)

    def _send_forward(self, phase: int) -> None:
        """Queue a send of the next unsent output chunk of ``phase``.

        Skipped on the final stage (first rank, phase 1) and for phase 0 on the
        last rank, whose outputs are handed to phase 1 locally.
        """
        if (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0):
            return

        chunk_id = self.current_send_f_chunk_id[phase]
        self.current_send_f_chunk_id[phase] += 1
        tensors = self.output_chunks[phase][chunk_id]

        comm.append_isend(self.comm_ops, tensors, self.next_rank if phase == 0 else self.prev_rank, self.group)

        if not self.return_outputs:
            # Freed in _free_tensors once the queued communication has completed.
            self.to_free.extend(tensors)

    def _recv_backward(self, phase: int) -> None:
        """Queue a receive of the next output-gradient chunk of ``phase``.

        No-op in forward-only mode, on the final stage, and for phase 0 on the
        last rank (whose gradients arrive locally from phase 1).
        """
        if self.forward_only:
            return

        if (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0):
            return

        self.current_recv_b_chunk_id[phase] += 1
        tensors = comm.append_irecv(self.comm_ops, self.next_rank if phase == 0 else self.prev_rank, self.group)
        self.output_grad_chunks[phase].append(tensors)

    def _send_backward(self, phase: int) -> None:
        """Queue a send of the next unsent input-gradient chunk of ``phase``.

        No-op in forward-only mode, for phase 0 on the first rank, and for
        phase 1 on the last rank.
        """
        if self.forward_only:
            return

        if (self.is_first_rank and phase == 0) or (self.is_last_rank and phase == 1):
            return

        chunk_id = self.current_send_b_chunk_id[phase]
        self.current_send_b_chunk_id[phase] += 1
        tensors = self.input_grad_chunks[phase][chunk_id]
        self.input_grad_chunks[phase][chunk_id] = None

        comm.append_isend(self.comm_ops, tensors, self.prev_rank if phase == 0 else self.next_rank, self.group)

    def _commit_and_wait_comm(self) -> None:
        """Launch all queued P2P ops as one batch, wait for them, then free sent tensors."""
        if not self.comm_ops:
            return
        reqs = dist.batch_isend_irecv(self.comm_ops)
        for req in reqs:
            req.wait()
        self.comm_ops = []
        self._free_tensors()

    def step(
        self,
        *inputs: Optional[torch.Tensor],
        num_chunks: int = 0,
        criterion: Optional[Callable] = None,
        labels: Optional[List[Optional[torch.Tensor]]] = None,
        return_outputs: bool = False,
    ) -> Tuple[Optional[torch.Tensor], Optional[Union[torch.Tensor, Tuple[torch.Tensor]]]]:
        """
        Execute a training or inference step.

        Arguments:
            *inputs: Module inputs. Required only on the first rank.
            num_chunks: The number of micro-batches. Must be at least ``2 * num_ranks``.
            criterion: Loss function, invoked as ``criterion(*outputs, *labels)``. Required only on the first rank.
            labels: Labels of the loss function. Required only on the first rank.
                ``None`` (the default) is treated as an empty list.
            return_outputs: Whether to return outputs on the first rank. Default: ``False``.

        Returns: (loss, outputs)
            loss: Loss for the batch. Returned only on the first rank.
            outputs: Module outputs. Returned only if ``return_outputs=True`` and on the first rank.

        """
        assert comm.TENSOR_SHAPES is not None and comm.TENSOR_DTYPE is not None, \
            "You need to call set_p2p_tensor_shapes and set_p2p_tensor_dtype before executing a step."
        # Avoid a shared mutable default; an omitted `labels` behaves like an empty list.
        if labels is None:
            labels = []
        self.forward_only = not torch.is_grad_enabled()
        self.return_outputs = return_outputs

        rank = self.rank
        num_ranks = self.num_ranks
        assert num_chunks > 0 and num_chunks >= num_ranks * 2, f"{num_chunks=}, {num_ranks=}"

        if not self.forward_only and self.is_first_rank:
            assert criterion is not None

        self._reset_states()

        if self.is_first_rank:
            self.input_chunks = (scatter(inputs, num_chunks, self.batch_dim), [])
            self.labels = scatter(labels, num_chunks, self.batch_dim)
            self.criterion = criterion

        # Step 1: nF0
        step_1 = (num_ranks - rank - 1) * 2
        for i in range(step_1):
            self._forward_chunk(0)

        # Step 2: nF0F1
        step_2 = rank + 1
        self._recv_forward(0)
        for i in range(step_2):
            self._forward_chunk(0, recv=False, send=False)
            self._recv_forward(0)
            self._forward_chunk(1, send=(not self.is_last_rank) or (i < step_2 - 1))
            self._send_forward(0)

        # Step 3: nB1W1F1 (Use zero bubble)
        step_3 = num_ranks - rank - 1
        for i in range(step_3):
            self._backward_chunk(1, enable_zb=True)
            self._recv_forward(1)
            self._weight_chunk()
            self._forward_chunk(1, recv=False)

        # Step 4 (Main step): nF0B1F1B0
        step_4 = num_chunks - num_ranks * 2 + rank + 1
        for i in range(step_4):
            if i == 0:
                if self.is_last_rank:
                    # NOTE: We don't overlap these two chunks to further reduce bubble size.
                    self._forward_chunk(0, recv=False, send=False)
                    self._send_forward(1)
                    self._backward_chunk(1, send=False)
                    self._send_forward(0)
                    self._send_backward(1)
                else:
                    self._forward_backward_chunk(0, 1, recv0=False)
            else:
                self._forward_backward_chunk(0, 1)
            self._forward_backward_chunk(1, 0)

        # Step 5: nB1F1B0
        step_5 = num_ranks - rank - 1
        for i in range(step_5):
            self._backward_chunk(1)
            self._forward_backward_chunk(1, 0)

        # Step 6: nB1B0 (The second half of the chunks use zero bubble)
        step_6 = rank + 1
        enable_zb = False
        for i in range(step_6):
            if i == step_6 // 2 and rank % 2 == 1:
                enable_zb = True
            self._backward_chunk(1, enable_zb=enable_zb)
            if i == step_6 // 2 and rank % 2 == 0:
                enable_zb = True
            self._backward_chunk(0, enable_zb=enable_zb)

        # Step 7: nWB0 (Use zero bubble)
        step_7 = num_ranks - rank - 1
        for i in range(step_7):
            self._weight_chunk()
            self._backward_chunk(0, enable_zb=True)

        # Step 8: nW
        step_8 = rank + 1
        for i in range(step_8):
            self._weight_chunk()
        assert WeightGradStore.funcs_queue.empty()

        self._commit_and_wait_comm()

        loss, outputs = None, None
        if self.is_first_rank:
            if criterion is not None:
                loss = torch.stack(self.loss_chunks)
            if return_outputs:
                outputs = gather(self.output_chunks[1], self.batch_dim)
                if len(outputs) == 1:
                    outputs = outputs[0]

        self._reset_states()

        return loss, outputs