step6_data_parallel_bucket/data_parallel.py [10:257]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
class DataParallelNaive(nn.Module):
    # Shouldn't be use with gradient accumulation (because it will accumulates gradient in bfloat16 and not float32)
    def __init__(self, module):
        super().__init__()
        self.module = module
        # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
        self.require_backward_grad_sync = True
        self.register_backward_hook(self._allreduce_grads)
    
    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)
    
    def register_backward_hook(self, hook):
        """Registers a backward hook for all parameters of the model that require gradients.""" 
        for p in self.module.parameters():
            if p.requires_grad is True:
                p.register_hook(hook)
                
    def _allreduce_grads(self, grad):
        """Performs an all-reduce operation to synchronize gradients across multiple processes."""
        # No synchronization needed during gradient accumulation, except at the final accumulation step.
        if self.require_backward_grad_sync:
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.dp_group)
            grad /= pgm.process_group_manager.dp_world_size
        return grad

### end Data Parallel (naive)

### begin Data Parallel (bucket)
class Bucket:
    def __init__(self, params: List[torch.nn.Parameter], grad_data: torch.Tensor, process_group: torch.distributed.ProcessGroup) -> None:
        # Set of parameters in this bucket.
        self.params = set(params)
        # Parameters that have their gradients ready for synchronization. launch all reduce when all parameters have gradients ready
        self.params_with_grad_ready = set()
        # Parameters that have their gradients ready for synchronization. launch all reduce when all parameters have gradients ready
        self.grad_data = grad_data
        # Process group for gradient synchronization.
        self.process_group = process_group
        self.process_group_size = dist.get_world_size(group=self.process_group) 
        # Handle for the async allreduce operation.
        self.handle = None
        
        self.reset()
    
    def sync_gradient(self) -> None:
        """Launch an asynchronous all-reduce operation to synchronize gradients across processes."""
        assert self.handle is None
        self.grad_data /= self.process_group_size
        self.handle = dist.all_reduce(self.grad_data, group=self.process_group, async_op=True)
    
    def reset(self) -> None:
        """Reset the bucket to its initial state. Typically called after the gradient synchronization is finished."""
        self.handle = None
        # Clear the set of parameters ready for gradient synchronization.
        self.params_with_grad_ready.clear()
        # Zero the gradient tensor.
        self.grad_data.zero_()

    def wait(self) -> None:
        """wait for the allreduce operation to finish"""
        assert self.handle is not None, "You should launch an allreduce operation before waiting for it to finish"
        # Block until the all-reduce operation finishes.
        self.handle.wait()

    def mark_param_as_ready(self, param: torch.nn.Parameter) -> None:
        """Mark a parameter as ready for gradient synchronization. Launches synchronization when all parameters in the bucket have their gradients ready."""
        assert param in self.params and param not in self.params_with_grad_ready
        self.params_with_grad_ready.add(param)
        # When all parameters in the bucket have their gradients ready, synchronize gradients
        if len(self.params_with_grad_ready) == len(self.params):
            self.sync_gradient()

class BucketManager:
    def __init__(self, params: List[torch.nn.Parameter], process_group: torch.distributed.ProcessGroup, bucket_size: int, grad_type: torch.dtype = torch.float32) -> None:
        # Convert parameter generator to a list.
        self.params = list(params)
        # List of buckets.
        self.buckets = []
        self.process_group = process_group
        self.process_group_size = dist.get_world_size(group=self.process_group)
        # Map each parameter to its corresponding bucket/place (start, end, bucket_idx).
        self.params_to_bucket_location = {}
        # Actual sizes of each bucket.
        self.bucket_size = bucket_size
        self.bucket_sizes = None
        # List of tensors to store gradients, one tensor per bucket.
        self.grad_data_list = []
        self.grad_type = grad_type
        # Divide gradients into buckets based on the provided bucket size.
        self._initialize_buckets()
    
    def _initialize_buckets(self) -> None:
        """Divides model parameters into buckets for gradient synchronization based on the bucket size."""
        cur_bucket_size = 0 
        cur_bucket_idx = 0
        
        # Assign parameters to buckets. 
        for param in self.params:
            if not param.requires_grad:
                continue
            
            # If the bucket is empty, add the parameter to the bucket.
            if cur_bucket_size == 0:
                self.params_to_bucket_location[param] = (0, param.numel(), cur_bucket_idx)
                cur_bucket_size = param.numel()
                continue
            
            # If the parameter cannot fit in the current bucket, create a new bucket
            if cur_bucket_size + param.numel() > self.bucket_size:
                cur_bucket_idx += 1
                self.params_to_bucket_location[param] = (0, param.numel(), cur_bucket_idx)
                cur_bucket_size = param.numel()
            else:
                self.params_to_bucket_location[param] = (cur_bucket_size, cur_bucket_size + param.numel(), cur_bucket_idx)
                cur_bucket_size += param.numel()

        # Gather information about the bucket sizes and the parameters in each bucket
        bucket_sizes = [0] * (cur_bucket_idx + 1)
        buckets_to_params = [[] for _ in range(cur_bucket_idx + 1)]
        for param, (_, end, idx) in self.params_to_bucket_location.items():
            bucket_sizes[idx] = max(bucket_sizes[idx], end)
            buckets_to_params[idx].append(param)
        
        # Create tensors for storing gradients and initialize Bucket objects.
        for i in range(len(bucket_sizes)):
            self.grad_data_list.append(torch.zeros(bucket_sizes[i], dtype=self.grad_type, device='cuda'))
            self.buckets.append(Bucket(buckets_to_params[i], self.grad_data_list[i], self.process_group))
        
        # Create gradient views for each parameter.
        for param in self.params[::-1]:
            if not param.requires_grad:
                continue
            data_start_index, data_end_index, bucket_id = self.params_to_bucket_location[param]
            # param.main_grad is used for gradient calculation
            param.main_grad = self._get_view_from_tensor(self.grad_data_list[bucket_id], param.shape, data_start_index, data_end_index)
            
    def _get_view_from_tensor(self, tensor: torch.Tensor, shape: torch.Size, start: int, end: int) -> torch.Tensor:
        return tensor[start:end].view(shape)
    
    def reset(self) -> None:
        # Reset all buckets by clearing the gradients and internal states.
        for bucket in self.buckets:
            bucket.reset()
    
    def wait(self) -> None:
        # Wait for all buckets to complete their gradient synchronization
        for bucket in self.buckets:
            bucket.wait()
    
    def mark_param_as_ready(self, param: torch.nn.Parameter) -> None:
        # Mark a parameter's gradient as ready for synchronization.
        bucket_idx = self.params_to_bucket_location[param][2]
        self.buckets[bucket_idx].mark_param_as_ready(param)

class DataParallelBucket(nn.Module):
    def __init__(self, module, bucket_cap_mb=25, grad_type = torch.float32):
        super().__init__()
        self.module = module
        # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
        self.require_backward_grad_sync = True
        grad_size = 2 # bfloat16 gradient: 2 bytes
        # number of gradients in one bucket
        # 1 MB = 1024 KB = 1024 * 1024 bytes
        bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size
        self.bucket_manager = BucketManager(module.parameters(), pgm.process_group_manager.dp_group, bucket_size, grad_type)
        self.register_backward_hook()
        # whether the callback for wait gradient synchronization is set
        self._post_backward_callback_set = False
        
    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def backward(self, input_tensor, output_tensor, output_tensor_grad):
        return self.module.backward(input_tensor, output_tensor, output_tensor_grad)
    
    def get_flops(self, *args, **kwargs):
        return self.module.get_flops(*args, **kwargs)
    
    def register_backward_hook(self):
        """
        Registers a backward hook to manually accumulate and synchronize gradients.
        
        This hook serves two main purposes:
        1. PyTorch does not natively support gradient accumulation with mixed precision.
        2. After gradient accumulation, it flags parameters as ready for synchronization.
        
        The gradient accumulation functions are stored to prevent them from going out of scope.
        
        References:
        - https://github.com/NVIDIA/Megatron-LM/issues/690
        - https://pytorch.org/docs/stable/generated/torch.autograd.graph.Node.register_hook.html
        - https://arxiv.org/abs/2006.15704 (page 5)
        """
        self.grad_accs = []
        for param in self.module.parameters():
            if param.requires_grad:
                # Expand so we get access to grad_fn.
                param_tmp = param.expand_as(param)
                # Get the gradient accumulator function.
                grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]
                grad_acc_fn.register_hook(self._make_param_hook(param, self.bucket_manager))
                self.grad_accs.append(grad_acc_fn)
                
    def _make_param_hook(self, param: torch.nn.Parameter,bucket_manager: BucketManager):
        """
        Creates the a hook for each parameter to handle gradient accumulation and synchronization.
        """
        def param_hook(*unused):
            """
            The hook called after the gradient is ready. It performs the following:
            1. Accumulates the gradient into the main gradient.
            2. Adds a post-backward callback to wait for gradient synchronization completion.
            3. Marks the parameter as ready for synchronization.
            """
            if param.requires_grad:
                assert param.grad is not None
                param.main_grad.add_(param.grad.data) # accumulate the gradients
                param.grad = None
                
                # skip the gradient synchronization (gradient accumulation/PP micro batches)
                if self.require_backward_grad_sync:
                    # Add a callback to wait for gradient synchronization. Ensures the callback is added only once.
                    # Callback is executed after the backward pass. It should be added per backward pass.
                    if not self._post_backward_callback_set:
                        torch.autograd.Variable._execution_engine.queue_callback(self._post_backward)
                        self._post_backward_callback_set = True
                        
                    # mark the parameter as ready for gradient synchronization. 
                    bucket_manager.mark_param_as_ready(param) 
        return param_hook

    def _post_backward(self):
        """
        A post-backward callback that waits for gradient synchronization to finish, then copies 
        the synchronized gradients back to the parameters' grad attribute.
        
        This method is called after the backward pass and before the optimizer step.
        """
        self.bucket_manager.wait()
        self._post_backward_callback_set = False
        # copy to params.grad so we can use the optimizer to update the parameters
        for p in self.module.parameters():
            if p.requires_grad:
                p.grad = p.main_grad.to(p.dtype) # In PyTorch, you cannot assign a gradient with one data type to a tensor of another data type.

    def reset(self):
        self.bucket_manager.reset()
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



step8_pipeline_parallel_1f1b/data_parallel.py [10:255]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
class DataParallelNaive(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
        # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
        self.require_backward_grad_sync = True
        self.register_backward_hook(self._allreduce_grads)
    
    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)
    
    def register_backward_hook(self, hook):
        """Registers a backward hook for all parameters of the model that require gradients.""" 
        for p in self.module.parameters():
            if p.requires_grad is True:
                p.register_hook(hook)
                
    def _allreduce_grads(self, grad):
        """Performs an all-reduce operation to synchronize gradients across multiple processes."""
        # No synchronization needed during gradient accumulation, except at the final accumulation step.
        if self.require_backward_grad_sync:
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.dp_group)
            grad /= pgm.process_group_manager.dp_world_size
        return grad

### end Data Parallel (naive)

### begin Data Parallel (bucket)
class Bucket:
    def __init__(self, params: List[torch.nn.Parameter], grad_data: torch.Tensor, process_group: torch.distributed.ProcessGroup) -> None:
        # Set of parameters in this bucket.
        self.params = set(params)
        # Parameters that have their gradients ready for synchronization. launch all reduce when all parameters have gradients ready
        self.params_with_grad_ready = set()
        # Parameters that have their gradients ready for synchronization. launch all reduce when all parameters have gradients ready
        self.grad_data = grad_data
        # Process group for gradient synchronization.
        self.process_group = process_group
        self.process_group_size = dist.get_world_size(group=self.process_group) 
        # Handle for the async allreduce operation.
        self.handle = None
        
        self.reset()
    
    def sync_gradient(self) -> None:
        """Launch an asynchronous all-reduce operation to synchronize gradients across processes."""
        assert self.handle is None
        self.grad_data /= self.process_group_size
        self.handle = dist.all_reduce(self.grad_data, group=self.process_group, async_op=True)
    
    def reset(self) -> None:
        """Reset the bucket to its initial state. Typically called after the gradient synchronization is finished."""
        self.handle = None
        # Clear the set of parameters ready for gradient synchronization.
        self.params_with_grad_ready.clear()
        # Zero the gradient tensor.
        self.grad_data.zero_()

    def wait(self) -> None:
        """wait for the allreduce operation to finish"""
        assert self.handle is not None, "You should launch an allreduce operation before waiting for it to finish"
        # Block until the all-reduce operation finishes.
        self.handle.wait()

    def mark_param_as_ready(self, param: torch.nn.Parameter) -> None:
        """Mark a parameter as ready for gradient synchronization. Launches synchronization when all parameters in the bucket have their gradients ready."""
        assert param in self.params and param not in self.params_with_grad_ready
        self.params_with_grad_ready.add(param)
        # When all parameters in the bucket have their gradients ready, synchronize gradients
        if len(self.params_with_grad_ready) == len(self.params):
            self.sync_gradient()

class BucketManager:
    def __init__(self, params: List[torch.nn.Parameter], process_group: torch.distributed.ProcessGroup, bucket_size: int, grad_type: torch.dtype = torch.float32) -> None:
        # Convert parameter generator to a list.
        self.params = list(params)
        # List of buckets.
        self.buckets = []
        self.process_group = process_group
        self.process_group_size = dist.get_world_size(group=self.process_group)
        # Map each parameter to its corresponding bucket/place (start, end, bucket_idx).
        self.params_to_bucket_location = {}
        # Actual sizes of each bucket.
        self.bucket_size = bucket_size
        self.bucket_sizes = None
        # List of tensors to store gradients, one tensor per bucket.
        self.grad_data_list = []
        self.grad_type = grad_type
        # Divide gradients into buckets based on the provided bucket size.
        self._initialize_buckets()
    
    def _initialize_buckets(self) -> None:
        """Divides model parameters into buckets for gradient synchronization based on the bucket size."""
        cur_bucket_size = 0 
        cur_bucket_idx = 0
        
        # Assign parameters to buckets. 
        for param in self.params:
            if not param.requires_grad:
                continue
            
            # If the bucket is empty, add the parameter to the bucket.
            if cur_bucket_size == 0:
                self.params_to_bucket_location[param] = (0, param.numel(), cur_bucket_idx)
                cur_bucket_size = param.numel()
                continue
            
            # If the parameter cannot fit in the current bucket, create a new bucket
            if cur_bucket_size + param.numel() > self.bucket_size:
                cur_bucket_idx += 1
                self.params_to_bucket_location[param] = (0, param.numel(), cur_bucket_idx)
                cur_bucket_size = param.numel()
            else:
                self.params_to_bucket_location[param] = (cur_bucket_size, cur_bucket_size + param.numel(), cur_bucket_idx)
                cur_bucket_size += param.numel()

        # Gather information about the bucket sizes and the parameters in each bucket
        bucket_sizes = [0] * (cur_bucket_idx + 1)
        buckets_to_params = [[] for _ in range(cur_bucket_idx + 1)]
        for param, (_, end, idx) in self.params_to_bucket_location.items():
            bucket_sizes[idx] = max(bucket_sizes[idx], end)
            buckets_to_params[idx].append(param)
        
        # Create tensors for storing gradients and initialize Bucket objects.
        for i in range(len(bucket_sizes)):
            self.grad_data_list.append(torch.zeros(bucket_sizes[i], dtype=self.grad_type, device='cuda'))
            self.buckets.append(Bucket(buckets_to_params[i], self.grad_data_list[i], self.process_group))
        
        # Create gradient views for each parameter.
        for param in self.params[::-1]:
            if not param.requires_grad:
                continue
            data_start_index, data_end_index, bucket_id = self.params_to_bucket_location[param]
            # param.main_grad is used for gradient calculation
            param.main_grad = self._get_view_from_tensor(self.grad_data_list[bucket_id], param.shape, data_start_index, data_end_index)
            
    def _get_view_from_tensor(self, tensor: torch.Tensor, shape: torch.Size, start: int, end: int) -> torch.Tensor:
        return tensor[start:end].view(shape)
    
    def reset(self) -> None:
        # Reset all buckets by clearing the gradients and internal states.
        for bucket in self.buckets:
            bucket.reset()
    
    def wait(self) -> None:
        # Wait for all buckets to complete their gradient synchronization
        for bucket in self.buckets:
            bucket.wait()
    
    def mark_param_as_ready(self, param: torch.nn.Parameter) -> None:
        # Mark a parameter's gradient as ready for synchronization.
        bucket_idx = self.params_to_bucket_location[param][2]
        self.buckets[bucket_idx].mark_param_as_ready(param)

class DataParallelBucket(nn.Module):
    def __init__(self, module, bucket_cap_mb=25, grad_type = torch.float32):
        super().__init__()
        self.module = module
        # whether to synchronize gradients during backward pass. Set to False when using gradient accumulation
        self.require_backward_grad_sync = True
        grad_size = 2 # bfloat16 gradient: 2 bytes
        # number of gradients in one bucket
        bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size
        self.bucket_manager = BucketManager(module.parameters(), pgm.process_group_manager.dp_group, bucket_size, grad_type)
        self.register_backward_hook()
        # whether the callback for wait gradient synchronization is set
        self._post_backward_callback_set = False
        
    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def backward(self, input_tensor, output_tensor, output_tensor_grad):
        return self.module.backward(input_tensor, output_tensor, output_tensor_grad)
    
    def get_flops(self, *args, **kwargs):
        return self.module.get_flops(*args, **kwargs)
    
    def register_backward_hook(self):
        """
        Registers a backward hook to manually accumulate and synchronize gradients.
        
        This hook serves two main purposes:
        1. PyTorch does not natively support gradient accumulation with mixed precision.
        2. After gradient accumulation, it flags parameters as ready for synchronization.
        
        The gradient accumulation functions are stored to prevent them from going out of scope.
        
        References:
        - https://github.com/NVIDIA/Megatron-LM/issues/690
        - https://pytorch.org/docs/stable/generated/torch.autograd.graph.Node.register_hook.html
        - https://arxiv.org/abs/2006.15704 (page 5)
        """
        self.grad_accs = []
        for param in self.module.parameters():
            if param.requires_grad:
                # Expand so we get access to grad_fn.
                param_tmp = param.expand_as(param)
                # Get the gradient accumulator function.
                grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]
                grad_acc_fn.register_hook(self._make_param_hook(param, self.bucket_manager))
                self.grad_accs.append(grad_acc_fn)
                
    def _make_param_hook(self, param: torch.nn.Parameter,bucket_manager: BucketManager):
        """
        Creates the a hook for each parameter to handle gradient accumulation and synchronization.
        """
        def param_hook(*unused):
            """
            The hook called after the gradient is ready. It performs the following:
            1. Accumulates the gradient into the main gradient.
            2. Adds a post-backward callback to wait for gradient synchronization completion.
            3. Marks the parameter as ready for synchronization.
            """
            if param.requires_grad:
                assert param.grad is not None
                param.main_grad.add_(param.grad.data) # accumulate the gradients
                param.grad = None
                
                # skip the gradient synchronization (gradient accumulation/PP micro batches)
                if self.require_backward_grad_sync:
                    # Add a callback to wait for gradient synchronization. Ensures the callback is added only once.
                    # Callback is executed after the backward pass. It should be added per backward pass.
                    if not self._post_backward_callback_set:
                        torch.autograd.Variable._execution_engine.queue_callback(self._post_backward)
                        self._post_backward_callback_set = True
                        
                    # mark the parameter as ready for gradient synchronization. 
                    bucket_manager.mark_param_as_ready(param) 
        return param_hook

    def _post_backward(self):
        """
        A post-backward callback that waits for gradient synchronization to finish, then copies 
        the synchronized gradients back to the parameters' grad attribute.
        
        This method is called after the backward pass and before the optimizer step.
        """
        self.bucket_manager.wait()
        self._post_backward_callback_set = False
        # copy to params.grad so we can use the optimizer to update the parameters
        for p in self.module.parameters():
            if p.requires_grad:
                p.grad = p.main_grad.to(p.dtype) # In PyTorch, you cannot assign a gradient with one data type to a tensor of another data type.

    def reset(self):
        self.bucket_manager.reset()
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



