tzrec/utils/dist_util.py
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Tuple
import torch
from torch import distributed as dist
from torch import nn
from torch.autograd.profiler import record_function
from torchrec.distributed import embeddingbag
from torchrec.distributed.embedding_types import (
KJTList,
)
from torchrec.distributed.embeddingbag import (
ShardedEmbeddingBagCollection,
)
from torchrec.distributed.mc_embedding_modules import (
BaseShardedManagedCollisionEmbeddingCollection,
ShrdCtx,
)
from torchrec.distributed.model_parallel import DataParallelWrapper
from torchrec.distributed.model_parallel import (
DistributedModelParallel as _DistributedModelParallel,
)
from torchrec.distributed.types import (
Awaitable,
ModuleSharder,
ShardingEnv,
ShardingPlan,
)
from torchrec.distributed.utils import none_throws
from torchrec.modules.embedding_configs import PoolingType
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, _to_offsets


def broadcast_string(s: str, src: int = 0) -> str:
"""Broadcasts a string from the source rank to all other ranks."""
if dist.get_rank() == src:
s_tensor = torch.ByteTensor(bytearray(s, "utf-8"))
length = torch.tensor([len(s_tensor)])
else:
length = torch.tensor([0], dtype=torch.long)
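    # The NCCL backend only communicates CUDA tensors, so move them to the
    # current device first.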
if dist.get_backend() == dist.Backend.NCCL:
length = length.cuda()
dist.broadcast(length, src)
if dist.get_rank() != src:
s_tensor = torch.ByteTensor(length.item())
if dist.get_backend() == dist.Backend.NCCL:
s_tensor = s_tensor.cuda()
# pyre-ignore [61]
dist.broadcast(s_tensor, src)
s_recv = s_tensor.cpu().numpy().tobytes().decode("utf-8")
return s_recv


def gather_strings(s: str, dst: int = 0) -> List[str]:
"""Gather strings from all ranks to the destination rank."""
rank = dist.get_rank()
world_size = dist.get_world_size()
s_tensor = torch.ByteTensor(bytearray(s, "utf-8"))
max_len = torch.tensor([len(s_tensor)], dtype=torch.long)
max_len_list = [torch.tensor([0], dtype=torch.long) for _ in range(world_size)]
if dist.get_backend() == dist.Backend.NCCL:
max_len = max_len.cuda()
max_len_list = [x.cuda() for x in max_len_list]
dist.all_gather(max_len_list, max_len)
# pyre-ignore [6]
max_len = max(max_len_list).item()
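    # Pad each rank's bytes to the global max length so that dist.gather
    # receives equally sized tensors.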
padded_s_tensor = torch.cat(
(s_tensor, torch.zeros(max_len - len(s_tensor), dtype=torch.uint8))
)
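    # Only the destination rank needs receive buffers; the other ranks pass an
    # empty gather_list to dist.gather.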
if rank == dst:
gather_list = [
torch.zeros(max_len, dtype=torch.uint8) for _ in range(world_size)
]
else:
gather_list = []
if dist.get_backend() == dist.Backend.NCCL:
padded_s_tensor = padded_s_tensor.cuda()
gather_list = [x.cuda() for x in gather_list]
dist.gather(padded_s_tensor, gather_list, dst)
gathered_strings = []
if rank == dst:
for tensor in gather_list:
string = tensor.cpu().numpy().tobytes().decode("utf-8").rstrip("\x00")
gathered_strings.append(string)
return gathered_strings


# The lengths of the input KJT are modified in place by torchrec's
# _create_mean_pooling_divisor; as a temporary fix, the copy below clones
# lengths before mutating it.
def _create_mean_pooling_divisor(
lengths: torch.Tensor,
keys: List[str],
offsets: torch.Tensor,
stride: int,
stride_per_key: List[int],
dim_per_key: torch.Tensor,
pooling_type_to_rs_features: Dict[str, List[str]],
embedding_names: List[str],
embedding_dims: List[int],
variable_batch_per_feature: bool,
kjt_inverse_order: torch.Tensor,
kjt_key_indices: Dict[str, int],
kt_key_ordering: torch.Tensor,
inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
with record_function("## ebc create mean pooling callback ##"):
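        # With variable batch sizes per feature, the batch size comes from the
        # inverse indices; otherwise it is the KJT stride.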
batch_size = (
none_throws(inverse_indices)[1].size(dim=1)
if variable_batch_per_feature
else stride
)
if weights is not None:
# if we have weights, lengths is the sum of weights by offsets for feature
lengths = torch.ops.fbgemm.segment_sum_csr(1, offsets.int(), weights)
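        # With variable batch sizes per feature, expand/reorder lengths to the
        # padded batch layout using the inverse indices.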
if variable_batch_per_feature:
inverse_indices = none_throws(inverse_indices)
device = inverse_indices[1].device
inverse_indices_t = inverse_indices[1]
if len(keys) != len(inverse_indices[0]):
inverse_indices_t = torch.index_select(
inverse_indices[1], 0, kjt_inverse_order
)
offsets = _to_offsets(torch.tensor(stride_per_key, device=device))[
:-1
].unsqueeze(-1)
indices = (inverse_indices_t + offsets).flatten()
lengths = torch.index_select(input=lengths, dim=0, index=indices)
        # only set the lengths of SUM pooling features to 1 (so mean division
        # becomes a no-op for them)
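        # Clone before the in-place write below so the caller's KJT lengths
        # are not mutated (see the module-level comment above).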
lengths = lengths.clone()
for feature in pooling_type_to_rs_features[PoolingType.SUM.value]:
feature_index = kjt_key_indices[feature]
feature_index = feature_index * batch_size
lengths[feature_index : feature_index + batch_size] = 1
if len(embedding_names) != len(keys):
lengths = torch.index_select(
lengths.reshape(-1, batch_size),
0,
kt_key_ordering,
).reshape(-1)
# transpose to align features with keyed tensor dim_per_key
lengths = lengths.reshape(-1, batch_size).T # [batch_size, num_features]
output_size = sum(embedding_dims)
divisor = torch.repeat_interleave(
input=lengths,
repeats=dim_per_key,
dim=1,
output_size=output_size,
)
        eps = 1e-6  # used to safeguard against division by zero
divisor = divisor + eps
return divisor.detach()
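

# Replace torchrec's implementation with the patched copy above so that the
# mean pooling divisor is computed on cloned lengths.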
# pyre-ignore [9]
embeddingbag._create_mean_pooling_divisor = _create_mean_pooling_divisor


# Fix the missing create_mean_pooling_callback in the input_dist of mc-ebc.
def _mc_input_dist(
# pyre-ignore [2]
self,
ctx: ShrdCtx,
features: KeyedJaggedTensor,
) -> Awaitable[Awaitable[KJTList]]:
if self._embedding_module._has_uninitialized_input_dist:
if isinstance(self._embedding_module, ShardedEmbeddingBagCollection):
self._features_order = []
            # disable feature permutation inside mc, because we need to permute
            # the features in mc-ebc before the mean pooling callback.
if self._managed_collision_collection._has_uninitialized_input_dists:
self._managed_collision_collection._create_input_dists(
input_feature_names=features.keys()
)
self._managed_collision_collection._has_uninitialized_input_dists = (
False
)
if self._managed_collision_collection._features_order:
self._features_order = (
self._managed_collision_collection._features_order
)
self._managed_collision_collection._features_order = []
if self._embedding_module._has_mean_pooling_callback:
self._embedding_module._init_mean_pooling_callback(
features.keys(),
# pyre-ignore [16]
ctx.inverse_indices,
)
self._embedding_module._has_uninitialized_input_dist = False
if isinstance(self._embedding_module, ShardedEmbeddingBagCollection):
with torch.no_grad():
if self._features_order:
features = features.permute(
self._features_order,
self._managed_collision_collection._features_order_tensor,
)
if self._embedding_module._has_mean_pooling_callback:
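                # Compute the mean pooling divisor on the (possibly permuted)
                # features and stash it on the sharding context; this is the
                # step missing from torchrec's mc-ebc input_dist.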
ctx.divisor = _create_mean_pooling_divisor(
lengths=features.lengths(),
stride=features.stride(),
keys=features.keys(),
offsets=features.offsets(),
pooling_type_to_rs_features=self._embedding_module._pooling_type_to_rs_features,
stride_per_key=features.stride_per_key(),
dim_per_key=self._embedding_module._dim_per_key,
embedding_names=self._embedding_module._embedding_names,
embedding_dims=self._embedding_module._embedding_dims,
# pyre-ignore [16]
variable_batch_per_feature=ctx.variable_batch_per_feature,
kjt_inverse_order=self._embedding_module._kjt_inverse_order,
kjt_key_indices=self._embedding_module._kjt_key_indices,
kt_key_ordering=self._embedding_module._kt_key_ordering,
inverse_indices=ctx.inverse_indices,
weights=features.weights_or_none(),
)
# TODO: resolve incompatibility with different contexts
return self._managed_collision_collection.input_dist(
ctx,
features,
)


BaseShardedManagedCollisionEmbeddingCollection.input_dist = _mc_input_dist


def DistributedModelParallel(
module: nn.Module,
env: Optional[ShardingEnv] = None,
device: Optional[torch.device] = None,
plan: Optional[ShardingPlan] = None,
sharders: Optional[List[ModuleSharder[torch.nn.Module]]] = None,
init_data_parallel: bool = True,
init_parameters: bool = True,
data_parallel_wrapper: Optional[DataParallelWrapper] = None,
) -> _DistributedModelParallel:
"""Entry point to model parallelism.
we custom ddp to make input_dist of ShardModel uninitialized.
mc-ebc now make _has_uninitialized_input_dist = True in init.
TODO: use torchrec DistributedModelParallel when torchrec fix it.
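
    Example (a minimal sketch; assumes the default process group is already
    initialized, and that my_module and local_rank are placeholders defined
    by the caller):

        model = DistributedModelParallel(
            module=my_module,
            device=torch.device("cuda", local_rank),
        )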
"""
model = _DistributedModelParallel(
module,
env,
device,
plan,
sharders,
init_data_parallel,
init_parameters,
data_parallel_wrapper,
)
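    # Force re-initialization of input_dist so that input_dist (e.g. the
    # patched _mc_input_dist above) is re-created on the first forward pass.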
for _, m in model.named_modules():
if hasattr(m, "_has_uninitialized_input_dist"):
m._has_uninitialized_input_dist = True
return model