# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# Copyright (C) 2023 Habana Labs, Ltd. an Intel Company
###############################################################################
"""
Fast and lightweight alternative to DistributeDataParallel for Habana Gaudi
"""
import torch
def all_reduce_gradients(
model: torch.nn.Module, fusion_buffer_dtype: torch.dtype = torch.bfloat16, use_hpu_graphs: bool = True
):
"""
Invokes an all-reduce operation on the gradients supporting data parallel training.
This function is meant to be called after forward+backward passes, where the gradient information is available in the model parameters.
Once called, the list of gradients participating in the training process must remain the same.
Args:
model (torch.nn.Module): A model whose gradients are meant to be all-reduced.
fusion_buffer_dtype (torch.dtype): The dtype of internally allocated gradient fusion buffer.
use_hpu_graphs (bool): Determines whether HPU graph recording should be used for packing and unpacking the gradients.
Raises:
NotImplementedError: `all_reduce_gradients()` does not support changing the set of active gradients after first invocation.
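
    Example:

        A minimal usage sketch, assuming `torch.distributed` has already been initialized with the HCCL backend and
        the model lives on an HPU device (`loss_fn`, `inputs`, `labels`, and `optimizer` are placeholder names):

        ```python
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        all_reduce_gradients(model)
        optimizer.step()
        optimizer.zero_grad()
        ```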
"""
    # Try to get the existing fusion buffer created for the model.
    fusion_entries = model.__dict__.get("_all_reduce_fusion_entries", None)
    if fusion_entries is not None:
        if len(fusion_entries) == 0:
            # There is nothing to all-reduce and no fusion buffer has been allocated.
            return
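        # Reuse the fusion buffer (and, when HPU graphs are enabled, the recorded pack/unpack graphs)
        # cached on the model by the first invocation.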
        fusion_buffer = model._all_reduce_fusion_buffer
        if use_hpu_graphs:
            pack_graph = model._all_reduce_gradient_pack_graph
            unpack_graph = model._all_reduce_gradient_unpack_graph
    else:
        # Count the total number of elements of the reduced gradients.
        grad_elem_count = 0
        for param in model.parameters():
            if param.grad is None:
                continue
            grad_elem_count += torch.numel(param.grad)

        # There is nothing to all-reduce.
        if grad_elem_count == 0:
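            # Remember that there is nothing to reduce for this model so that subsequent calls return early.
            # Gradients appearing only after the first call are not picked up (see the restriction documented above).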
model.__dict__["_all_reduce_fusion_entries"] = []
return
# Allocate the fusion buffer and associate it with the model.
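        # A single flat buffer lets all gradients be exchanged with one collective call instead of one per parameter.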
        fusion_buffer = torch.zeros(size=(grad_elem_count,), dtype=fusion_buffer_dtype, device="hpu:0")
        model.__dict__["_all_reduce_fusion_buffer"] = fusion_buffer

        # Build the fusion information necessary for gradient packing and unpacking processes.
        grad_elem_count = 0
        fusion_entries = []
        for param in model.parameters():
            if param.grad is None:
                continue
            grad_numel = torch.numel(param.grad)
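            # Each gradient maps onto a contiguous slice of the fusion buffer, viewed with the gradient's shape.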
            fused_view = fusion_buffer[grad_elem_count : grad_elem_count + grad_numel].reshape(param.grad.shape)
            fusion_entries.append((param, fused_view))
            grad_elem_count += grad_numel
        model.__dict__["_all_reduce_fusion_entries"] = fusion_entries
        # Instruct the following logic to record packing and unpacking HPU graphs based on the newly created fusion buffer.
        if use_hpu_graphs:
            pack_graph = None
            unpack_graph = None

    # Pack the gradients into the fusion buffer.
    def pack_grads():
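        # Pre-scaling each gradient by 1/world_size makes the summing all-reduce below equivalent to averaging
        # the gradients across all ranks; gradients are also cast to the fusion buffer dtype if needed.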
        world_size_inv = 1.0 / torch.distributed.group.WORLD.size()
        for param, fused_view in fusion_entries:
            grad = param.grad
            if grad is None:
                raise NotImplementedError(
                    "`all_reduce_gradients()` does not support changing the set of active gradients after first invocation."
                )

            if grad.dtype != fusion_buffer_dtype:
                grad = grad.to(fusion_buffer_dtype)
            grad = grad * world_size_inv
            fused_view.copy_(grad, non_blocking=True)
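    # The packing copies are captured into an HPU graph on the first call and replayed afterwards,
    # which reduces host-side launch overhead on subsequent iterations.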
    if use_hpu_graphs:
        if pack_graph is None:
            import habana_frameworks.torch as ht

            pack_graph = ht.hpu.HPUGraph()
            with ht.hpu.stream(ht.hpu.Stream()):
                pack_graph.capture_begin()
                pack_grads()
                pack_graph.capture_end()
            model.__dict__["_all_reduce_gradient_pack_graph"] = pack_graph

        pack_graph.replay()
    else:
        pack_grads()
    # Invoke an all-reduce operation of the fused gradients.
    torch.distributed.all_reduce(fusion_buffer, group=torch.distributed.group.WORLD, async_op=True)
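    # async_op=True issues the collective without blocking the host; its result is read back into the
    # per-parameter gradients by the unpack step below.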
    # Unpack the gradients back to the model parameters.
    def unpack_grads():
        for param, fused_view in fusion_entries:
            grad = param.grad
            if grad is None:
                raise NotImplementedError(
                    "`all_reduce_gradients()` does not support changing the set of active gradients after first invocation."
                )

            if fused_view.dtype != grad.dtype:
                fused_view = fused_view.to(grad.dtype)

            grad.copy_(fused_view, non_blocking=True)
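    # Like packing, the unpacking copies are captured into an HPU graph once and replayed on later calls.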
    if use_hpu_graphs:
        if unpack_graph is None:
            import habana_frameworks.torch as ht

            unpack_graph = ht.hpu.HPUGraph()
            with ht.hpu.stream(ht.hpu.Stream()):
                unpack_graph.capture_begin()
                unpack_grads()
                unpack_graph.capture_end()
            model.__dict__["_all_reduce_gradient_unpack_graph"] = unpack_graph

        unpack_graph.replay()
    else:
        unpack_grads()