# chatlearn/utils/dist_utils.py
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""distributed utils"""
from collections import defaultdict
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def bucket_tensors(tensors, bucket_size_mb):
    """Group tensors into buckets. Sparse and dense tensors are separated, and each
    dense bucket holds tensors of the same type whose total size stays within a byte limit.

    Args:
        tensors (Sequence): A sequence of tensors to be separated into buckets.
        bucket_size_mb (int): The size limit of each dense bucket in megabytes;
            a non-positive value disables the limit.

    Returns:
        dense_buckets: A list of buckets, each containing dense tensors of the same
            type and within the size limit.
        sparse_bucket: A list of sparse tensors.
    """
size_limit = bucket_size_mb * 1024 * 1024
buf_dict = defaultdict(lambda: [[], 0])
dense_buckets = []
sparse_bucket = []
for tensor in tensors:
if tensor.is_sparse:
sparse_bucket.append(tensor)
continue
t = tensor.type()
size = tensor.numel() * tensor.element_size()
buf_and_size = buf_dict[t]
if size_limit > 0 and buf_and_size[1] + size > size_limit and buf_and_size[1] > 0: # pylint: disable=chained-comparison
dense_buckets.append(buf_and_size[0])
buf_and_size = buf_dict[t] = [[], 0]
buf_and_size[0].append(tensor)
buf_and_size[1] += size
for buf, _ in buf_dict.values():
if len(buf) > 0:
dense_buckets.append(buf)
return dense_buckets, sparse_bucket
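

# Illustrative usage sketch (not part of the original module): bucket a mix of
# dense tensors by dtype under a hypothetical 128 MB limit, then flatten each
# bucket into one contiguous buffer, as a coalesced collective would. The tensor
# shapes and the 128 MB value are assumptions for illustration only.
def _example_bucket_tensors():
    tensors = [torch.randn(1024, 1024), torch.randn(2048, 2048).half(), torch.randn(512)]
    dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=128)
    for bucket in dense_buckets:
        flat = _flatten_dense_tensors(bucket)  # one contiguous buffer per same-dtype bucket
        assert flat.numel() == sum(t.numel() for t in bucket)
    return dense_buckets, sparse_bucket
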
def bucket_tensors_two_stage_generator(tensor_generator, bucket_size_mb, stage2=False, tensor_changed=False):
    """Group tensors into buckets. Sparse and dense tensors are separated, and each
    dense bucket holds tensors of the same type whose total size stays within a byte limit.

    Args:
        tensor_generator (Generator): A generator yielding (tensor, buffer_num) pairs
            to be separated into buckets.
        bucket_size_mb (int): The size limit of each dense bucket in megabytes;
            a non-positive value disables the limit.
        stage2 (bool): If True, entry sizes are not expanded by buffer_num.
        tensor_changed (bool): If True, allocate an empty receive buffer for entries
            that are expanded (buffer_num > 1) instead of reusing the local tensor.

    Yields:
        bucket_or_tensor: A bucket of dense tensors of the same type within the size
            limit, or a single sparse tensor.
        is_dense: True if bucket_or_tensor is a dense-tensor bucket, False if it is
            a sparse tensor.
    """
size_limit = bucket_size_mb * 1024 * 1024
buf_dict = defaultdict(lambda: [[], 0])
for tensor, buffer_num in tensor_generator():
if tensor.is_sparse:
yield tensor, False
continue
buffer_multiple = 1 if stage2 else buffer_num
t = tensor.type()
        # Expand the buffer size on dst ranks that receive tensors from the trainer.
size = tensor.numel() * tensor.element_size() * buffer_multiple
buf_and_size = buf_dict[t]
if size_limit > 0 and buf_and_size[1] + size > size_limit and buf_and_size[1] > 0: # pylint: disable=chained-comparison
yield buf_and_size[0], True
buf_and_size = buf_dict[t] = [[], 0]
if tensor_changed and buffer_multiple > 1:
empty_or_curr_tensor = torch.empty(
size=[tensor.numel() * buffer_multiple],
dtype=tensor.dtype,
device=tensor.device
)
else:
empty_or_curr_tensor = tensor
buf_and_size[0].append((
empty_or_curr_tensor,
[size // tensor.element_size(), buffer_multiple, tensor]
))
buf_and_size[1] += size
    for buf, _ in buf_dict.values():
if len(buf) > 0:
yield buf, True
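

# Illustrative usage sketch (not part of the original module): drive the two-stage
# bucketing generator with a hypothetical (tensor, buffer_num) generator. With
# buffer_num > 1 and tensor_changed=True, each dense entry carries an expanded
# receive buffer plus metadata [total_numel, multiple, orig_tensor].
def _example_two_stage_buckets():
    def tensor_generator():
        for _ in range(4):
            yield torch.randn(1024, 1024), 2  # (tensor, buffer_num); values are assumptions
    for bucket_or_tensor, is_dense in bucket_tensors_two_stage_generator(
            tensor_generator, bucket_size_mb=128, stage2=False, tensor_changed=True):
        if is_dense:
            for buffer_tensor, (total_numel, multiple, orig_tensor) in bucket_or_tensor:
                assert buffer_tensor.numel() == total_numel == orig_tensor.numel() * multiple
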
def unflatten_dense_tensors(flat_tensors, tensors, sizes, num_ranks):
    """Split a flattened bucket back into per-rank views shaped like the original tensors.

    `sizes` holds per-entry metadata [total_numel, multiple, orig_tensor] as produced by
    bucket_tensors_two_stage_generator; each expanded entry is cut into `num_ranks` slices.
    Returns a dict mapping each rank index to a list of original-shaped views.
    """
all_buffers = defaultdict(list)
offset = 0
for size_multiple, tensor in zip(sizes, tensors):
size, multiple, orig_tensor = size_multiple
assert offset <= flat_tensors.numel()
assert len(flat_tensors.shape) == 1
flat_tensor = flat_tensors[offset:offset+size]
per_size = size // multiple
for rank in range(num_ranks):
if multiple > 1:
                assert (flat_tensor.numel() // multiple) == tensor.numel(), (
                    f"flat_tensor: {flat_tensor.shape} should be {multiple} times of tensor {orig_tensor.shape}, "
                    f"per_size: {per_size} total_size: {size} num_ranks: {num_ranks} offset: {offset}"
                )
all_buffers[rank].append(flat_tensor[rank * per_size:(rank + 1) * per_size].view(orig_tensor.shape))
else:
assert flat_tensor.numel() == orig_tensor.numel(), \
f"flat_tensor: {flat_tensor.shape} orig_tensor: {orig_tensor.shape}"
all_buffers[rank].append(flat_tensor.view(orig_tensor.shape))
del flat_tensor
offset += size
del flat_tensors
return all_buffers
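

# Illustrative usage sketch (not part of the original module): round-trip a
# hypothetical two-rank bucket through flatten / unflatten. Each expanded slot
# stores num_ranks copies of a tensor, and unflatten_dense_tensors returns one
# list of original-shaped views per rank.
def _example_unflatten_round_trip():
    num_ranks = 2
    orig = [torch.randn(4, 4), torch.randn(8)]
    expanded = [torch.cat([t.flatten()] * num_ranks) for t in orig]
    sizes = [[t.numel() * num_ranks, num_ranks, t] for t in orig]
    flat = _flatten_dense_tensors(expanded)
    per_rank = unflatten_dense_tensors(flat, orig, sizes, num_ranks)
    assert torch.equal(per_rank[0][0], orig[0]) and torch.equal(per_rank[1][1], orig[1])
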
def coalesced_comm_dense(bucket, comm_call, extra_args, tensor_changed=True):
    """Coalesced communication for a bucket of dense parameters.

    Flatten the bucket into a single buffer, invoke comm_call on it with extra_args,
    and, if tensor_changed is True, copy the communicated values back into the
    original tensors.
    """
flat_tensors = _flatten_dense_tensors(bucket)
comm_call(flat_tensors, *extra_args)
if tensor_changed:
for tensor, synced in zip(
bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
tensor.copy_(synced)
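

# Illustrative usage sketch (not part of the original module): all-reduce every
# dense bucket produced by bucket_tensors. Assumes an initialized torch.distributed
# process group; `params` and the 128 MB limit are hypothetical. Sparse tensors are
# returned separately by bucket_tensors and are not handled here.
def _example_coalesced_allreduce(params):
    dense_buckets, _sparse_bucket = bucket_tensors(params, bucket_size_mb=128)
    for bucket in dense_buckets:
        # comm_call receives the flattened buffer; extra_args is forwarded positionally.
        coalesced_comm_dense(bucket, torch.distributed.all_reduce, (), tensor_changed=True)
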
def coalesced_comm_dense_two_stage(bucket, comm_call, rank, extra_args, tensor_changed=True, stage2=False, index=0):
    """Coalesced two-stage communication for a bucket of dense parameters.

    Flatten the bucket entries produced by bucket_tensors_two_stage_generator and invoke
    comm_call on the flattened buffer. If tensor_changed is True, copy the slice selected
    by `index` (slice 0 when stage2 is True) back into the original tensors and return
    the remaining per-rank buffers; otherwise return None.
    """
all_tensors = []
all_sizes = []
num_ranks = 1
orig_tensor_ele = 0
orig_tensors = []
for tensor, size in bucket:
all_tensors.append(tensor)
all_sizes.append(size)
orig_tensors.append(size[2])
orig_tensor_ele += size[2].numel()
num_ranks = max(num_ranks, size[1])
flat_tensors = _flatten_dense_tensors(all_tensors)
del all_tensors
comm_call(flat_tensors, *extra_args)
if tensor_changed:
index = 0 if stage2 else index
all_buffers = unflatten_dense_tensors(flat_tensors, orig_tensors, all_sizes, num_ranks)
for tensor, synced in zip(orig_tensors, all_buffers[index]):
assert tensor.numel() == synced.numel(), \
f"rank {rank} tensor {tensor.shape} should be equal to synced.shape {synced.shape}, for all_sizes {all_sizes}"
tensor.copy_(synced)
del all_buffers[index]
return all_buffers
return None
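

# Illustrative usage sketch (not part of the original module): broadcast one dense
# bucket built by bucket_tensors_two_stage_generator from src_rank, then recover the
# buffers destined for the other ranks. `src_rank` and `group` are hypothetical
# placeholders supplied by the caller.
def _example_two_stage_broadcast(bucket, src_rank, group=None):
    remaining_buffers = coalesced_comm_dense_two_stage(
        bucket,
        torch.distributed.broadcast,   # called as comm_call(flat_tensors, src_rank, group)
        rank=torch.distributed.get_rank(),
        extra_args=(src_rank, group),
        tensor_changed=True,
        stage2=False,
        index=0,                       # copy slice 0 back into the original tensors
    )
    return remaining_buffers
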
def broadcast_var_object_dict(obj_dict, src_rank):
    """Broadcast a variable-length dict of picklable objects from src_rank to all ranks.

    The number of items is broadcast first so that non-source ranks can allocate a
    placeholder list for torch.distributed.broadcast_object_list. Returns the dict on
    every rank.
    """
if torch.distributed.get_rank() == src_rank:
dict_as_list = list(obj_dict.items())
list_length = len(dict_as_list)
length_tensor = torch.tensor(list_length, device='cuda')
torch.distributed.broadcast(length_tensor, src_rank)
torch.distributed.broadcast_object_list(dict_as_list, src=src_rank)
return obj_dict
else:
length_tensor = torch.tensor(0, device='cuda')
torch.distributed.broadcast(length_tensor, src_rank)
list_length = length_tensor.item()
dict_as_list = [None] * list_length
torch.distributed.broadcast_object_list(dict_as_list, src=src_rank)
return dict(dict_as_list)
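

# Illustrative usage sketch (not part of the original module): broadcast a small dict
# of picklable metadata from rank 0. The keys and values are hypothetical; assumes an
# initialized process group and a CUDA device for the length tensor.
def _example_broadcast_metadata():
    rank = torch.distributed.get_rank()
    metadata = {"step": 10, "lr": 1e-5} if rank == 0 else {}
    metadata = broadcast_var_object_dict(metadata, src_rank=0)
    return metadata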