def get_momentum()

in torchrec/distributed/batched_embedding_kernel.py [0:0]


                def get_momentum(momentum_idx: int) -> ShardedTensor:
                    """Assemble one optimizer momentum buffer of this table as a ShardedTensor.

                    For every local shard in ``shard_params``, the state tensor at
                    position ``momentum_idx - 1`` of that shard's optimizer states
                    becomes one ``Shard``.

                    If that state tensor is 1-D, its local and global metadata are
                    derived from the shard's metadata via
                    ``to_rowwise_sharded_metadata``. Otherwise the state tensor is
                    given the shard's own local metadata and the table's
                    ``global_metadata``.

                    Args:
                        momentum_idx: 1-based index of the momentum to extract
                            (must be > 0; ``momentum_idx - 1`` is the list position).

                    Returns:
                        A ShardedTensor built from the local momentum shards on the
                        process group ``self._pg``.
                    """
                    assert momentum_idx > 0
                    # optimizer_states entries are 0-based; momentum_idx is 1-based.
                    state_idx = momentum_idx - 1
                    momentum_local_shards: List[Shard] = []

                    sharded_tensor_metadata = table_config.global_metadata
                    for optimizer_state, shard_param_local_metadata in zip(
                        shard_params.optimizer_states, shard_params.local_metadata
                    ):
                        momentum = optimizer_state[state_idx]

                        # Pair each momentum with *its own* shard's metadata rather
                        # than table_config.local_metadata: shard_params may hold
                        # several local shards of this table, and reusing a single
                        # metadata would give every Shard the same offsets/sizes.
                        local_metadata = shard_param_local_metadata

                        if momentum.dim() == 1:
                            # 1-D state: presumably one value per embedding row, so
                            # derive row-wise local/global metadata from this shard.
                            # NOTE(review): assumes all shards of a table agree on the
                            # state's rank; mixing 1-D and N-D shards would leave the
                            # global metadata inconsistent.
                            (
                                local_metadata,
                                sharded_tensor_metadata,
                            ) = to_rowwise_sharded_metadata(
                                shard_param_local_metadata,
                                table_config.global_metadata,
                                sharding_dim,
                            )

                        assert local_metadata is not None
                        assert sharded_tensor_metadata is not None
                        momentum_local_shards.append(Shard(momentum, local_metadata))

                    return ShardedTensor._init_from_local_shards_and_global_metadata(
                        local_shards=momentum_local_shards,
                        sharded_tensor_metadata=sharded_tensor_metadata,
                        process_group=self._pg,
                    )