in torchrec/quant/embedding_modules.py [0:0]
def _to_data_type(dtype: torch.dtype) -> DataType:
    """Map a quantized ``torch.dtype`` to the corresponding ``DataType``.

    Args:
        dtype: The quantized torch dtype to convert.

    Returns:
        ``DataType.INT8``, ``DataType.INT4`` or ``DataType.INT2``.

    Raises:
        ValueError: If ``dtype`` has no supported ``DataType`` equivalent.
            (``ValueError`` subclasses ``Exception``, so callers that caught
            the previous bare ``Exception`` keep working.)
    """
    # PyTorch's sub-byte quantized dtypes are the packed ``quint4x2`` and
    # ``quint2x4``. The names ``quint4``/``qint4``/``quint2``/``qint2`` are
    # not guaranteed to exist on ``torch``; referencing them directly raises
    # AttributeError for any non-INT8 input. Resolve every candidate with
    # getattr so only the dtypes present in the installed torch are compared.
    candidates = (
        ("quint8", DataType.INT8),
        ("qint8", DataType.INT8),
        ("quint4x2", DataType.INT4),
        ("quint4", DataType.INT4),
        ("qint4", DataType.INT4),
        ("quint2x4", DataType.INT2),
        ("quint2", DataType.INT2),
        ("qint2", DataType.INT2),
    )
    for attr_name, data_type in candidates:
        torch_dtype = getattr(torch, attr_name, None)
        if torch_dtype is not None and dtype == torch_dtype:
            return data_type
    raise ValueError(f"Invalid data type {dtype}")