dualpipe/dualpipev.py
from typing import Tuple, List, Union, Callable, Optional
import torch
import torch.nn as nn
import torch.distributed as dist
import dualpipe.comm as comm
from dualpipe.utils import WeightGradStore, run_backward, scatter, gather
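# DualPipeV schedules a V-shaped ("dual pipe") pipeline: every rank holds two model
# chunks, phase 0 on the way "down" the pipeline and phase 1 on the way back "up".
# Rank 0 therefore owns both the first and the last pipeline stage (and computes the
# loss), while the last rank sits at the fold of the V and passes phase-0 outputs
# directly into phase-1 inputs without any communication.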
class DualPipeV(nn.Module):
def __init__(
self,
modules: Tuple[nn.Module, nn.Module],
batch_dim: int = 0,
process_group: Optional[dist.ProcessGroup] = None,
rank_mapping: Optional[List[int]] = None,
) -> None:
super().__init__()
assert next(modules[0].parameters()).device == torch.device(torch.cuda.current_device())
self.module = nn.ModuleList(modules)
        self.overlapped_forward_backward = type(modules[0]) is type(modules[1]) and hasattr(type(modules[0]), "overlapped_forward_backward")
self.batch_dim = batch_dim
self.group = process_group or dist.distributed_c10d._get_default_group()
self.num_ranks = self.group.size()
# rank_mapping: Map rank in process_group to actual pp rank.
# rank_inverse_mapping: Map actual pp rank to rank in process_group.
if rank_mapping is None:
rank_mapping = list(range(self.num_ranks))
rank_inverse_mapping = [None] * (self.num_ranks + 1)
for i in range(self.num_ranks):
rank_inverse_mapping[rank_mapping[i]] = i
self.rank = rank_mapping[self.group.rank()]
self.prev_rank = rank_inverse_mapping[self.rank - 1]
self.next_rank = rank_inverse_mapping[self.rank + 1]
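        # Illustrative example (hypothetical values): with num_ranks == 4 and
        # rank_mapping == [3, 2, 1, 0], process-group rank 0 acts as pipeline rank 3,
        # and prev_rank/next_rank are resolved back to process-group ranks through
        # rank_inverse_mapping so that P2P ops address the right peers.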
self.is_first_rank = self.rank == 0
self.is_last_rank = self.rank == self.num_ranks - 1
def _reset_states(self) -> None:
WeightGradStore.clear()
self.input_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
self.output_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
self.input_grad_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
self.output_grad_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
        self.labels: Optional[List[List[torch.Tensor]]] = None
self.loss_chunks: List[torch.Tensor] = []
        self.criterion: Optional[Callable] = None
self.current_f_chunk_id: List[int] = [0, 0]
self.current_b_chunk_id: List[int] = [0, 0]
self.current_send_f_chunk_id: List[int] = [0, 0]
self.current_send_b_chunk_id: List[int] = [0, 0]
self.current_recv_f_chunk_id: List[int] = [0, 0]
self.current_recv_b_chunk_id: List[int] = [0, 0]
self.comm_ops: List[dist.P2POp] = []
self.to_free: List[torch.Tensor] = []
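    # Forward for one micro-batch of the given phase. On the first rank, phase 1 is
    # the last pipeline stage, so the loss is computed there; on the last rank (the
    # fold of the V), phase-0 outputs are detached and queued as phase-1 inputs.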
def _forward_compute_chunk(self, phase: int) -> None:
chunk_id = self.current_f_chunk_id[phase]
self.current_f_chunk_id[phase] += 1
inputs = self.input_chunks[phase][chunk_id]
if self.forward_only:
self.input_chunks[phase][chunk_id] = None
is_last_stage = (self.is_first_rank and phase == 1)
outputs = self.module[phase](*inputs)
outputs = [outputs] if isinstance(outputs, torch.Tensor) else outputs
if is_last_stage and self.criterion is not None:
labels = self.labels[chunk_id]
loss = self.criterion(*outputs, *labels)
self.loss_chunks.append(loss)
if self.is_last_rank and phase == 0:
self.input_chunks[1].append([output.detach().requires_grad_() for output in outputs])
if (not is_last_stage) or self.return_outputs:
self.output_chunks[phase].append(outputs)
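    # Backward for one micro-batch. With enable_zb, the weight-gradient part of the
    # backward is deferred into WeightGradStore (zero-bubble style) and executed later
    # by _weight_chunk, so only input gradients are produced here. On the last rank,
    # phase-1 input gradients are re-queued locally as phase-0 output gradients.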
def _backward_compute_chunk(self, phase: int, enable_zb: bool = False) -> None:
if self.forward_only:
return
chunk_id = self.current_b_chunk_id[phase]
self.current_b_chunk_id[phase] += 1
is_last_stage = (self.is_first_rank and phase == 1)
WeightGradStore.enabled = enable_zb
if is_last_stage:
loss = self.loss_chunks[chunk_id]
loss.backward()
loss.detach_()
else:
outputs = self.output_chunks[phase][chunk_id]
if not self.return_outputs:
self.output_chunks[phase][chunk_id] = None
output_grads = self.output_grad_chunks[phase][chunk_id]
self.output_grad_chunks[phase][chunk_id] = None
            non_empty = [(t, g) for t, g in zip(outputs, output_grads) if g is not None]
            if non_empty:  # zip(*[]) cannot be unpacked into two names, so guard first
                outputs, output_grads = zip(*non_empty)
                run_backward(outputs, output_grads)
WeightGradStore.enabled = False
if enable_zb:
WeightGradStore.flush()
inputs = self.input_chunks[phase][chunk_id]
self.input_chunks[phase][chunk_id] = None
input_grads = [t.grad for t in inputs]
if self.is_last_rank and phase == 1:
self.output_grad_chunks[0].append(input_grads)
else:
self.input_grad_chunks[phase].append(input_grads)
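    # Runs a phase0 forward chunk and a phase1 backward chunk together. When both
    # modules are the same class and provide overlapped_forward_backward, the two are
    # overlapped inside that user-defined hook; otherwise they run sequentially using
    # the plain methods above.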
def _forward_backward_compute_chunk(self, phase0: int, phase1: int) -> None:
if self.forward_only:
self._forward_compute_chunk(phase0)
return
if not self.overlapped_forward_backward:
self._forward_compute_chunk(phase0)
self._backward_compute_chunk(phase1)
return
# pre-forward
chunk_id0 = self.current_f_chunk_id[phase0]
self.current_f_chunk_id[phase0] += 1
module0 = self.module[phase0]
inputs0 = self.input_chunks[phase0][chunk_id0]
is_last_stage0 = (self.is_first_rank and phase0 == 1)
if is_last_stage0 and self.criterion is not None:
labels0 = self.labels[chunk_id0]
criterion0 = self.criterion
else:
labels0 = []
criterion0 = None
# pre-backward
chunk_id1 = self.current_b_chunk_id[phase1]
self.current_b_chunk_id[phase1] += 1
module1 = self.module[phase1]
is_last_stage1 = (self.is_first_rank and phase1 == 1)
if is_last_stage1:
loss1 = self.loss_chunks[chunk_id1]
outputs1 = []
output_grads1 = []
else:
loss1 = None
outputs1 = self.output_chunks[phase1][chunk_id1]
if not self.return_outputs:
self.output_chunks[phase1][chunk_id1] = None
output_grads1 = self.output_grad_chunks[phase1][chunk_id1]
self.output_grad_chunks[phase1][chunk_id1] = None
            # Keep only outputs that actually received a gradient; fall back to empty
            # tuples if none did, instead of failing to unpack zip(*[]).
            non_empty = [(t, g) for t, g in zip(outputs1, output_grads1) if g is not None]
            outputs1, output_grads1 = zip(*non_empty) if non_empty else ((), ())
# forward & backward
outputs0, loss0 = type(module0).overlapped_forward_backward(
module0, inputs0, criterion0, labels0,
module1, loss1, outputs1, output_grads1,
)
# post-forward
if self.is_last_rank and phase0 == 0:
self.input_chunks[1].append([output.detach().requires_grad_() for output in outputs0])
if (not is_last_stage0) or self.return_outputs:
self.output_chunks[phase0].append(outputs0)
if is_last_stage0 and self.criterion is not None:
self.loss_chunks.append(loss0)
# post-backward
inputs = self.input_chunks[phase1][chunk_id1]
self.input_chunks[phase1][chunk_id1] = None
input_grads1 = [t.grad for t in inputs]
if self.is_last_rank and phase1 == 1:
self.output_grad_chunks[0].append(input_grads1)
else:
self.input_grad_chunks[phase1].append(input_grads1)
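    # The *_chunk wrappers below pair compute with communication: receives are posted
    # first, _commit_and_wait_comm drains all pending async P2P ops, the compute runs,
    # and the results are posted as sends for the neighbouring rank.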
def _forward_chunk(self, phase: int, recv: bool = True, send: bool = True) -> None:
if recv:
self._recv_forward(phase)
self._commit_and_wait_comm()
self._forward_compute_chunk(phase)
if send:
self._send_forward(phase)
def _backward_chunk(self, phase: int, enable_zb: bool = False, recv: bool = True, send: bool = True) -> None:
if recv:
self._recv_backward(phase)
self._commit_and_wait_comm()
self._backward_compute_chunk(phase, enable_zb)
if send:
self._send_backward(phase)
def _forward_backward_chunk(self, phase0: int, phase1: int, recv0: bool = True) -> None:
if recv0:
self._recv_forward(phase0)
self._recv_backward(phase1)
self._commit_and_wait_comm()
self._forward_backward_compute_chunk(phase0, phase1)
self._send_forward(phase0)
self._send_backward(phase1)
def _weight_chunk(self) -> None:
if self.forward_only:
return
self._commit_and_wait_comm()
        # WeightGradStore is assumed to be FIFO: pop() runs the earliest deferred
        # weight-gradient computation.
WeightGradStore.pop()
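    # Releases the storage of activation tensors that have already been sent; they are
    # collected in to_free only when return_outputs is False and freed after the
    # corresponding isend completes in _commit_and_wait_comm.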
def _free_tensors(self) -> None:
for tensor in self.to_free:
assert tensor._base is None, f"pipeline stage should not return view tensors {dist.get_rank(), tensor.shape}"
tensor.data = torch.Tensor()
self.to_free = []
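    # P2P helpers: phase-0 tensors flow from prev_rank to next_rank and phase-1 tensors
    # flow in the opposite direction (gradients mirror this). The first rank has no
    # external producer for phase 0 or consumer for phase 1, and the last rank loops
    # phase 0 into phase 1 locally, so those transfers are skipped.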
def _recv_forward(self, phase: int) -> None:
if (self.is_first_rank and phase == 0) or (self.is_last_rank and phase == 1):
return
self.current_recv_f_chunk_id[phase] += 1
tensors = comm.append_irecv(self.comm_ops, self.prev_rank if phase == 0 else self.next_rank, self.group)
self.input_chunks[phase].append(tensors)
def _send_forward(self, phase: int) -> None:
if (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0):
return
chunk_id = self.current_send_f_chunk_id[phase]
self.current_send_f_chunk_id[phase] += 1
tensors = self.output_chunks[phase][chunk_id]
comm.append_isend(self.comm_ops, tensors, self.next_rank if phase == 0 else self.prev_rank, self.group)
if not self.return_outputs:
self.to_free.extend(tensors)
def _recv_backward(self, phase: int) -> None:
if self.forward_only:
return
if (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0):
return
self.current_recv_b_chunk_id[phase] += 1
tensors = comm.append_irecv(self.comm_ops, self.next_rank if phase == 0 else self.prev_rank, self.group)
self.output_grad_chunks[phase].append(tensors)
def _send_backward(self, phase: int) -> None:
if self.forward_only:
return
if (self.is_first_rank and phase == 0) or (self.is_last_rank and phase == 1):
return
chunk_id = self.current_send_b_chunk_id[phase]
self.current_send_b_chunk_id[phase] += 1
tensors = self.input_grad_chunks[phase][chunk_id]
self.input_grad_chunks[phase][chunk_id] = None
comm.append_isend(self.comm_ops, tensors, self.prev_rank if phase == 0 else self.next_rank, self.group)
def _commit_and_wait_comm(self) -> None:
if not self.comm_ops:
return
reqs = dist.batch_isend_irecv(self.comm_ops)
for req in reqs:
req.wait()
self.comm_ops = []
self._free_tensors()
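    # Minimal usage sketch (illustrative; the exact location and signature of the
    # shape/dtype setters are an assumption based on the assert in step() below):
    #
    #   comm.set_p2p_tensor_shapes([(micro_batch_size, seq_len, hidden_dim)])
    #   comm.set_p2p_tensor_dtype(torch.bfloat16)
    #   loss, outputs = model.step(x, num_chunks=n, criterion=criterion, labels=[y])
    #
    # inputs, criterion and labels only need to be provided on the first rank.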
def step(
self,
*inputs: Optional[torch.Tensor],
num_chunks: int = 0,
criterion: Optional[Callable] = None,
labels: List[Optional[torch.Tensor]] = [],
return_outputs: bool = False,
) -> Tuple[Optional[torch.Tensor], Optional[Union[torch.Tensor, Tuple[torch.Tensor]]]]:
"""
Execute a training or inference step.
Arguments:
*inputs: Module inputs. Required only on the first rank.
num_chunks: The number of micro-batches.
criterion: Loss function, invoked as ``criterion(*outputs, *labels)``. Required only on the first rank.
labels: Labels of the loss function. Required only on the first rank.
return_outputs: Whether to return outputs on the first rank. Default: ``False``.
Returns: (loss, outputs)
loss: Loss for the batch. Returned only on the first rank.
outputs: Module outputs. Returned only if ``return_outputs=True`` and on the first rank.
"""
assert comm.TENSOR_SHAPES is not None and comm.TENSOR_DTYPE is not None, \
"You need to call set_p2p_tensor_shapes and set_p2p_tensor_dtype before executing a step."
self.forward_only = not torch.is_grad_enabled()
self.return_outputs = return_outputs
rank = self.rank
num_ranks = self.num_ranks
assert num_chunks > 0 and num_chunks >= num_ranks * 2, f"{num_chunks=}, {num_ranks=}"
if not self.forward_only and self.is_first_rank:
assert criterion is not None
self._reset_states()
if self.is_first_rank:
self.input_chunks = (scatter(inputs, num_chunks, self.batch_dim), [])
self.labels = scatter(labels, num_chunks, self.batch_dim)
self.criterion = criterion
# Step 1: nF0
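        # Warm-up: (num_ranks - rank - 1) * 2 phase-0 forwards; the last rank, at the
        # fold of the V, needs no warm-up at all.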
step_1 = (num_ranks - rank - 1) * 2
for i in range(step_1):
self._forward_chunk(0)
# Step 2: nF0F1
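        # Interleave phase-0 and phase-1 forwards; the last rank defers the send of its
        # final phase-1 output until the start of the main step.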
step_2 = rank + 1
self._recv_forward(0)
for i in range(step_2):
self._forward_chunk(0, recv=False, send=False)
self._recv_forward(0)
self._forward_chunk(1, send=(not self.is_last_rank) or (i < step_2 - 1))
self._send_forward(0)
# Step 3: nB1W1F1 (Use zero bubble)
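        # Start phase-1 backwards with weight gradients deferred (zero bubble), each
        # interleaved with one deferred-weight pass and one more phase-1 forward.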
step_3 = num_ranks - rank - 1
for i in range(step_3):
self._backward_chunk(1, enable_zb=True)
self._recv_forward(1)
self._weight_chunk()
self._forward_chunk(1, recv=False)
# Step 4 (Main step): nF0B1F1B0
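        # Steady state: each iteration overlaps a phase-0 forward with a phase-1
        # backward and then a phase-1 forward with a phase-0 backward, advancing both
        # halves of the V by one micro-batch.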
step_4 = num_chunks - num_ranks * 2 + rank + 1
for i in range(step_4):
if i == 0:
if self.is_last_rank:
# NOTE: We don't overlap these two chunks to further reduce bubble size.
self._forward_chunk(0, recv=False, send=False)
self._send_forward(1)
self._backward_chunk(1, send=False)
self._send_forward(0)
self._send_backward(1)
else:
self._forward_backward_chunk(0, 1, recv0=False)
else:
self._forward_backward_chunk(0, 1)
self._forward_backward_chunk(1, 0)
# Step 5: nB1F1B0
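        # Cool-down begins: no more phase-0 forwards; drain phase-1 backwards while the
        # remaining phase-1 forwards overlap with phase-0 backwards.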
step_5 = num_ranks - rank - 1
for i in range(step_5):
self._backward_chunk(1)
self._forward_backward_chunk(1, 0)
# Step 6: nB1B0 (The second half of the chunks use zero bubble)
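        # Remaining backwards of both phases; from roughly the halfway point on, weight
        # gradients are deferred again, with the exact switch point depending on rank
        # parity.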
step_6 = rank + 1
enable_zb = False
for i in range(step_6):
if i == step_6 // 2 and rank % 2 == 1:
enable_zb = True
self._backward_chunk(1, enable_zb=enable_zb)
if i == step_6 // 2 and rank % 2 == 0:
enable_zb = True
self._backward_chunk(0, enable_zb=enable_zb)
# Step 7: nWB0 (Use zero bubble)
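        # Drain deferred weight gradients, interleaved with the final phase-0 backwards
        # (still deferring their own weight gradients).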
step_7 = num_ranks - rank - 1
for i in range(step_7):
self._weight_chunk()
self._backward_chunk(0, enable_zb=True)
# Step 8: nW
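        # Flush whatever deferred weight-gradient computations remain.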
step_8 = rank + 1
for i in range(step_8):
self._weight_chunk()
assert WeightGradStore.funcs_queue.empty()
self._commit_and_wait_comm()
loss, outputs = None, None
if self.is_first_rank:
if criterion is not None:
loss = torch.stack(self.loss_chunks)
if return_outputs:
outputs = gather(self.output_chunks[1], self.batch_dim)
if len(outputs) == 1:
outputs = outputs[0]
self._reset_states()
return loss, outputs