def from_float()

in torchrec/distributed/quant_embedding_kernel.py


    @classmethod
    def from_float(cls, module: BaseEmbeddingBag) -> "QuantBatchedEmbeddingBag":
        assert hasattr(
            module, "qconfig"
        ), "EmbeddingBagCollectionInterface input float module must have qconfig defined"

        def _to_data_type(dtype: torch.dtype) -> DataType:
            if dtype == torch.quint8 or dtype == torch.qint8:
                return DataType.INT8
            elif dtype == torch.quint4x2:
                return DataType.INT4
            elif dtype == torch.quint2x4:
                return DataType.INT2
            else:
                raise Exception(f"Invalid data type {dtype}")

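        # The target precision comes from the float module's qconfig: the
        # weight observer's dtype is mapped to the quantized embedding DataType.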
        # pyre-ignore [16]
        data_type = _to_data_type(module.qconfig.weight().dtype)
        sparse_type = QuantBatchedEmbeddingBag.to_sparse_type(data_type)

        state_dict = dict(
            itertools.chain(module.named_buffers(), module.named_parameters())
        )
        device = next(iter(state_dict.values())).device

        # Adjust config to quantized version.
        # This obviously doesn't work for column-wise sharding.
        # pyre-ignore [29]
        config = copy.deepcopy(module.config())
        config.data_type = data_type
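        # Embedding dims in the sharding metadata are rewritten as per-row byte
        # sizes, since the quantized kernel stores each row as packed n-bit
        # values followed by a 4-byte fp16 scale/shift.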
        for table in config.embedding_tables:
            table.local_cols = rounded_row_size_in_bytes(table.local_cols, sparse_type)
            if table.local_metadata is not None:
                table.local_metadata.shard_sizes = [
                    table.local_rows,
                    table.local_cols,
                ]

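            # The remaining (non-local) shards in the global metadata need the
            # same column-to-bytes conversion; the local shard was already
            # updated above, so it is skipped here.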
            if table.global_metadata is not None:
                for shard_meta in table.global_metadata.shards_metadata:
                    if shard_meta != table.local_metadata:
                        shard_meta.shard_sizes = [
                            shard_meta.shard_sizes[0],
                            rounded_row_size_in_bytes(
                                shard_meta.shard_sizes[1], sparse_type
                            ),
                        ]
                table.global_metadata.size = torch.Size(
                    [
                        table.global_metadata.size[0],
                        sum(
                            shard_meta.shard_sizes[1]
                            for shard_meta in table.global_metadata.shards_metadata
                        ),
                    ]
                )

        ret = QuantBatchedEmbeddingBag(config=config, device=device)

        # Quantize weights.
        quant_weight_list = []
        for _, weight in state_dict.items():
            quantized_weights = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
                weight, DATA_TYPE_NUM_BITS[data_type]
            )
            # Split each row into the packed weight bytes and the trailing
            # 4-byte scale/shift (2 x fp16).
            quant_weight = quantized_weights[:, :-4]
            scale_shift = quantized_weights[:, -4:]

            quant_weight_list.append((quant_weight, scale_shift))
        ret.emb_module.assign_embedding_weights(quant_weight_list)

        return ret
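
For reference, a minimal standalone sketch of the weight conversion performed above, assuming fbgemm_gpu is installed so that the torch.ops.fbgemm quantization operators are registered (the table shape and bit rate below are made up for illustration):

    import torch
    import fbgemm_gpu  # noqa: F401  (registers the torch.ops.fbgemm operators)

    # Hypothetical table: 100 rows of dimension 16, quantized to 4 bits per value.
    weight = torch.randn(100, 16, dtype=torch.float32)
    bit_rate = 4  # corresponds to DATA_TYPE_NUM_BITS[DataType.INT4]

    # Row-wise quantization: each output row holds the packed 4-bit values
    # followed by a 4-byte scale/shift (2 x fp16).
    quantized = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(weight, bit_rate)
    quant_weight, scale_shift = quantized[:, :-4], quantized[:, -4:]

    # Round-trip back to float to inspect the quantization error.
    dequantized = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloat(quantized, bit_rate)
    print((weight - dequantized).abs().max())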