in torchrec/distributed/batched_embedding_kernel.py
def get_momentum(momentum_idx: int) -> ShardedTensor:
    assert momentum_idx > 0
    momentum_local_shards: List[Shard] = []
    sharded_tensor_metadata = table_config.global_metadata
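    # One entry per local shard of this table: optimizer_state holds that
    # shard's momentum tensors, shard_param_local_metadata its placement.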
    for (optimizer_state, shard_param_local_metadata) in zip(
        shard_params.optimizer_states, shard_params.local_metadata
    ):
        local_metadata = shard_param_local_metadata
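        # A 1-D momentum (one value per embedding row, e.g. row-wise Adagrad)
        # cannot reuse the 2-D parameter metadata, so derive row-wise metadata.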
        if optimizer_state[momentum_idx - 1].dim() == 1:
            (
                local_metadata,
                sharded_tensor_metadata,
            ) = to_rowwise_sharded_metadata(
                shard_param_local_metadata,
                table_config.global_metadata,
                sharding_dim,
            )
        assert local_metadata is not None
        assert sharded_tensor_metadata is not None
        momentum_local_shards.append(
            Shard(optimizer_state[momentum_idx - 1], local_metadata)
        )
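    # Stitch the local shards into a ShardedTensor spanning the process group.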
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards=momentum_local_shards,
        sharded_tensor_metadata=sharded_tensor_metadata,
        process_group=self._pg,
    )
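For context, a minimal sketch of how a closure like get_momentum is typically consumed when populating the fused optimizer's state: each momentum slot is exposed as a ShardedTensor keyed by table name, but only when every local shard actually carries that optimizer state. The state and weight_param containers and the ".momentum1"/".momentum2" key names below are assumptions for illustration, not taken from this excerpt.

# Hypothetical usage sketch (state, weight_param, and key names are assumed):
if all(len(opt_state) >= 1 for opt_state in shard_params.optimizer_states):
    state[weight_param][f"{table_config.name}.momentum1"] = get_momentum(1)
if all(len(opt_state) >= 2 for opt_state in shard_params.optimizer_states):
    state[weight_param][f"{table_config.name}.momentum2"] = get_momentum(2)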