in src/nanotron/parallel/pipeline_parallel/block.py
def forward(self, **kwargs):
"""Forward pass
We use a mechanism using TensorPointers to pass Tensors around
All non Tensor object or TensorPointers are considered pass-through, they are never meant to be communicated cross process
:param kwargs: Dict[str, Union[TensorPointer, torch.Tensor, Any]]
:return: Dict[str, Union[TensorPointer, torch.Tensor, Any]
"""
assert self.module_input_keys == set(
kwargs.keys()
), f"Expected {self.module_input_keys}, got {set(kwargs.keys())}"
sorted_kwargs = sorted(kwargs.items(), key=get_sort_key(dist.get_rank(self.p2p.pg)))
# If the current rank is not the one that runs the compute
if dist.get_rank(self.p2p.pg) != self.rank:
# TODO(kunhao): A better design is to pop this up for both if else branches.
batch_send_recv = BatchTensorSendRecvState(self.p2p)
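# BatchTensorSendRecvState collects the sends/recvs queued below and only executes them when
# flush() is called, so several tensors can be communicated together.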
# Send locally-held activations to the rank that runs the compute
for name, tensor in sorted_kwargs:
if isinstance(tensor, TensorPointer):
# Current rank neither holds the data nor runs the compute for this block, so nothing to do
continue
else:
assert isinstance(tensor, torch.Tensor)
# We need to send the tensor to the rank that actually runs the compute
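# When a pipeline engine is attached (pipeline_state is not None), the send is registered in the
# engine's state buffer so the engine can schedule it, instead of being issued eagerly here.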
if self.pipeline_state is not None:
send_to_pipeline_state_buffer(
tensor,
to_rank=self.rank,
p2p=self.p2p,
pipeline_state=self.pipeline_state,
)
continue
if tensor.requires_grad is True:
raise ValueError(
f"Pipeline engine is None and tensor requires grad. Tried sending a tensor to {self.rank}. Usually that means that your model is pipeline sharded and you haven't chosen a specific pipeline engine."
)
batch_send_recv.add_send(tensor=tensor, to_rank=self.rank)
batch_send_recv.flush()
# Return TensorPointers signalling that all outputs live on the rank responsible for computing this block
# TODO @thomasw21: Figure out a way to build dummy_input in a generic sense, and remove the necessity to have Dict[str, torch.Tensor] as output
return {k: TensorPointer(group_rank=self.rank) for k in self.module_output_keys}
# Current rank runs the compute: receive activations from the ranks that hold them
new_kwargs: Dict[str, torch.Tensor] = {}
name_to_recv_id = {}
batch_send_recv = BatchTensorSendRecvState(self.p2p)
for name, tensor in sorted_kwargs:
if isinstance(tensor, TensorPointer):
# Current rank is the one running the compute, so the tensor has to be fetched from the rank that owns it
# new_kwargs[name] = recv_tensor(from_rank=tensor.group_rank, p2p=self.p2p)
# This assumes that prior communication was already done
# In case of interleaved 1f1b, if this is the second model chunk, then we need to send the previous activations before receiving the current activations
if isinstance(self.pipeline_state, PipelineTrainBatchState):
for _ in range(len(self.pipeline_state.microbatches_activations_to_send)):
send_activation = self.pipeline_state.microbatches_activations_to_send.popleft()
# Execute
send_activation()
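# With a pipeline engine attached, the activation is popped from the engine's state buffer,
# which assumes the matching send was already registered on the owning rank.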
if self.pipeline_state is not None:
new_kwargs[name] = recv_from_pipeline_state_buffer(
from_rank=tensor.group_rank,
p2p=self.p2p,
pipeline_state=self.pipeline_state,
)
continue
# No pipeline engine: queue the receive directly instead of going through the pipeline state buffer
recv_id = batch_send_recv.add_recv(from_rank=tensor.group_rank)
name_to_recv_id[name] = recv_id
else:
new_kwargs[name] = tensor
# Run receiving communications
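# flush() executes every queued receive and returns the received tensors; each id handed back by
# add_recv() above indexes into this result.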
recv_tensors = batch_send_recv.flush()
assert len(recv_tensors) == len(name_to_recv_id)
for name, recv_id in name_to_recv_id.items():
assert name not in new_kwargs
new_tensor = recv_tensors[recv_id]
if new_tensor.requires_grad is True:
raise ValueError(
f"Pipeline engine is None and tensor requires grad. Tried receiving a tensor to {self.rank}. Usually that means that your model is pipeline sharded and you haven't chosen a specific pipeline engine."
)
new_kwargs[name] = new_tensor
output = self.pp_block(**new_kwargs)
# Convenience for modules that return a bare tensor: wrap it into a single-key dict
if isinstance(output, torch.Tensor):
assert len(self.module_output_keys) == 1
output = {next(iter(self.module_output_keys)): output}
assert isinstance(output, dict), "Modules within a Pipeline Block have to return a Dict[str, torch.Tensor]"
assert self.module_output_keys == set(
output.keys()
), f"Expected {self.module_output_keys}, got {set(output.keys())}"
return output
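
# Illustrative usage sketch (not from the source; names and ranks are hypothetical). Assume a
# PipelineBlock `block` whose wrapped module consumes and produces a single "hidden_states" key,
# computed on rank 1 while rank 0 holds the data; construction of `block`, `p2p` and the process
# group is omitted.
if dist.get_rank(p2p.pg) == 0:
    # Rank holding the data: the real tensor is sent to rank 1, a TensorPointer comes back.
    out = block(hidden_states=hidden_states)  # out["hidden_states"] -> TensorPointer(group_rank=1)
else:
    # Compute rank: the pointer is resolved by receiving from rank 0, and pp_block runs locally.
    out = block(hidden_states=TensorPointer(group_rank=0))  # out["hidden_states"] -> torch.Tensor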