in torchrec/distributed/quant_embedding_kernel.py
@classmethod
def from_float(cls, module: BaseEmbeddingBag) -> "QuantBatchedEmbeddingBag":
    assert hasattr(
        module, "qconfig"
    ), "EmbeddingBagCollectionInterface input float module must have qconfig defined"
    def _to_data_type(dtype: torch.dtype) -> DataType:
        if dtype == torch.quint8 or dtype == torch.qint8:
            return DataType.INT8
        elif dtype == torch.quint4x2:
            return DataType.INT4
        elif dtype == torch.quint2x4:
            return DataType.INT2
        else:
            raise Exception(f"Invalid data type {dtype}")
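
    # qconfig.weight() instantiates the configured weight observer; its dtype
    # encodes the target embedding precision.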
    # pyre-ignore [16]
    data_type = _to_data_type(module.qconfig.weight().dtype)
    sparse_type = QuantBatchedEmbeddingBag.to_sparse_type(data_type)
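
    # Collect the float module's per-table weights (buffers and parameters)
    # and infer the device from the first tensor.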
    state_dict = dict(
        itertools.chain(module.named_buffers(), module.named_parameters())
    )
    device = next(iter(state_dict.values())).device

    # Adjust the config to its quantized version: local_cols is re-expressed as
    # the rounded row size in bytes. This does not work for column-wise sharding.
    # pyre-ignore [29]
    config = copy.deepcopy(module.config())
    config.data_type = data_type
    for table in config.embedding_tables:
        table.local_cols = rounded_row_size_in_bytes(table.local_cols, sparse_type)
        if table.local_metadata is not None:
            table.local_metadata.shard_sizes = [
                table.local_rows,
                table.local_cols,
            ]
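
        # Mirror the byte-size rounding into the global sharding metadata so the
        # remaining shards' column sizes match the quantized row layout.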
        if table.global_metadata is not None:
            for shard_meta in table.global_metadata.shards_metadata:
                if shard_meta != table.local_metadata:
                    shard_meta.shard_sizes = [
                        shard_meta.shard_sizes[0],
                        rounded_row_size_in_bytes(
                            shard_meta.shard_sizes[1], sparse_type
                        ),
                    ]
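
            # Recompute the overall global size from the per-shard column sizes.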
            table.global_metadata.size = torch.Size(
                [
                    table.global_metadata.size[0],
                    sum(
                        shard_meta.shard_sizes[1]
                        for shard_meta in table.global_metadata.shards_metadata
                    ),
                ]
            )

    ret = QuantBatchedEmbeddingBag(config=config, device=device)

    # Quantize weights.
    quant_weight_list = []
    for _, weight in state_dict.items():
        quantized_weights = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
            weight, DATA_TYPE_NUM_BITS[data_type]
        )
        # FBGEMM fuses each quantized row with a trailing 4-byte scale/shift
        # (2 x fp16); split the two parts back out here.
        quant_weight = quantized_weights[:, :-4]
        scale_shift = quantized_weights[:, -4:]
        quant_weight_list.append((quant_weight, scale_shift))
    ret.emb_module.assign_embedding_weights(quant_weight_list)
    return ret
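
# Usage sketch (an assumption for illustration, not taken from this file):
# `float_ebc` stands for a float BaseEmbeddingBag coming out of the sharding
# flow, and the PlaceholderObserver-based qconfig is just one way to request
# int8 weights.
#
#   import torch
#   import torch.quantization as quant
#
#   float_ebc.qconfig = quant.QConfig(
#       activation=quant.PlaceholderObserver.with_args(dtype=torch.float),
#       weight=quant.PlaceholderObserver.with_args(dtype=torch.qint8),
#   )
#   quant_ebc = QuantBatchedEmbeddingBag.from_float(float_ebc)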