def to_sparse_type()

in torchrec/quant/embedding_modules.py [0:0]


        def to_sparse_type(data_type: DataType) -> SparseType:
            if data_type == DataType.FP16:
                return SparseType.FP16
            elif data_type == DataType.INT8:
                return SparseType.INT8
            elif data_type == DataType.INT4:
                return SparseType.INT4
            elif data_type == DataType.INT2:
                return SparseType.INT2
            else:
                raise ValueError(f"Invalid DataType {data_type}")