in torchrec/quant/embedding_modules.py [0:0]
def to_sparse_type(data_type: DataType) -> SparseType:
if data_type == DataType.FP16:
return SparseType.FP16
elif data_type == DataType.INT8:
return SparseType.INT8
elif data_type == DataType.INT4:
return SparseType.INT4
elif data_type == DataType.INT2:
return SparseType.INT2
else:
raise ValueError(f"Invalid DataType {data_type}")