in torchrec/distributed/cw_sharding.py [0:0]
def _init_combined_embeddings(self) -> None:
"""
Grabs the embedding names and dims from TwEmbeddingSharder.
Note:
This could have duplications if there are multiple shards from the same
table on a rank. Later on we process these to combine shards together.
"""
    embedding_names: List[str] = super().embedding_names()
    embedding_dims: List[int] = super().embedding_dims()
    embedding_shard_metadata: List[
        Optional[ShardMetadata]
    ] = super().embedding_shard_metadata()
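    # Map each embedding name to the (index, column offset) of every shard of
    # that table on this rank; the column offset comes from shard_offsets[1].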
    embedding_name_to_index_offset_tuples: Dict[str, List[Tuple[int, int]]] = {}
    for i, (name, metadata) in enumerate(
        zip(embedding_names, embedding_shard_metadata)
    ):
        if name not in embedding_name_to_index_offset_tuples:
            embedding_name_to_index_offset_tuples[name] = []
        embedding_name_to_index_offset_tuples[name].append(
            (i, metadata.shard_offsets[1] if metadata is not None else 0)
        )
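    # For each table, sort its shard indices by column offset (lowest first)
    # so the shards are handled in their original column order.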
    embedding_name_to_index: Dict[str, List[int]] = {}
    for name, index_offset_tuples in embedding_name_to_index_offset_tuples.items():
        embedding_name_to_index[name] = [
            idx_off_tuple[0]
            for idx_off_tuple in sorted(
                index_offset_tuples,
                key=lambda idx_off_tuple: idx_off_tuple[1],
            )
        ]
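    # De-duplicate embedding names while preserving first-seen order.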
    combined_embedding_names: List[str] = []
    seen_embedding_names: Set[str] = set()
    for name in embedding_names:
        if name not in seen_embedding_names:
            combined_embedding_names.append(name)
            seen_embedding_names.add(name)
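    # A combined embedding's dim is the sum of its shards' dims;
    # embedding_order is the flat list of shard indices grouped table by
    # table, each table's shards in column order.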
    combined_embedding_dims: List[int] = []
    embedding_order: List[int] = []
    for name in combined_embedding_names:
        combined_embedding_dims.append(
            sum([embedding_dims[idx] for idx in embedding_name_to_index[name]])
        )
        embedding_order.extend(embedding_name_to_index[name])
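    # Stash both the raw per-shard lists and the combined views on the instance.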
    self._embedding_names: List[str] = embedding_names
    self._embedding_dims: List[int] = embedding_dims
    self._embedding_order: List[int] = embedding_order
    self._combined_embedding_names: List[str] = combined_embedding_names
    self._combined_embedding_dims: List[int] = combined_embedding_dims
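

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of torchrec): the same combine logic as above,
# applied to plain Python lists with hypothetical inputs. Names, dims, and
# offsets here are made up; the sketch reuses the typing imports already in
# scope (List, Dict, Tuple).
def _combine_shards_sketch(
    names: List[str], dims: List[int], offsets: List[int]
) -> Tuple[List[str], List[int], List[int]]:
    # name -> [(shard index, column offset), ...]
    name_to_index_offset: Dict[str, List[Tuple[int, int]]] = {}
    for i, (name, offset) in enumerate(zip(names, offsets)):
        name_to_index_offset.setdefault(name, []).append((i, offset))
    # De-duplicate names while keeping first-seen order.
    combined_names: List[str] = list(dict.fromkeys(names))
    combined_dims: List[int] = []
    order: List[int] = []
    for name in combined_names:
        # Sort each table's shards by column offset before summing dims.
        indices = [
            i for i, _ in sorted(name_to_index_offset[name], key=lambda t: t[1])
        ]
        combined_dims.append(sum(dims[i] for i in indices))
        order.extend(indices)
    return combined_names, combined_dims, order


# Expected: (["t1", "t2"], [128, 80], [1, 0, 2]) -- the shard at offset 0 first.
# _combine_shards_sketch(["t1", "t1", "t2"], [64, 64, 80], [64, 0, 0])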