fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py (88 lines of code) (raw):

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import enum
from typing import Dict


@enum.unique
class EmbOptimType(enum.Enum):
    """Optimizer variants supported by the split-embedding kernels.

    The string value of each member is the canonical optimizer name used
    elsewhere in the codebase (``str(opt)`` returns it directly).
    """

    # uses non-deterministic updates (atomicAdd(..)) with duplicate ids
    SGD = "sgd"
    # uses deterministic updates (via sorting + segment reduction)
    EXACT_SGD = "exact_sgd"
    LAMB = "lamb"
    ADAM = "adam"
    # exact/dedup: gradients to the same row are applied with coalesce then apply
    # together, instead of applied in sequence (approx).
    EXACT_ADAGRAD = "exact_adagrad"
    EXACT_ROWWISE_ADAGRAD = "exact_row_wise_adagrad"
    LARS_SGD = "lars_sgd"
    PARTIAL_ROWWISE_ADAM = "partial_row_wise_adam"
    PARTIAL_ROWWISE_LAMB = "partial_row_wise_lamb"
    ROWWISE_ADAGRAD = "row_wise_adagrad"
    MADGRAD = "madgrad"
    EXACT_ROWWISE_WEIGHTED_ADAGRAD = "exact_row_wise_weighted_adagrad"

    def __str__(self) -> str:
        return self.value


@enum.unique
class SparseType(enum.Enum):
    """Element dtypes for embedding weights, with helpers to convert to/from
    the integer codes and bit widths used by the kernels."""

    FP32 = "fp32"
    FP16 = "fp16"
    INT8 = "int8"
    INT4 = "int4"
    INT2 = "int2"
    BF16 = "bf16"

    def __str__(self) -> str:
        return self.value

    @staticmethod
    def from_int(ty: int) -> "SparseType":
        """Return the SparseType for integer code ``ty`` (inverse of as_int).

        Raises:
            ValueError: if ``ty`` is not one of the known codes 0..5.
        """
        # Table lookup replaces the if/elif chain; ordering matches as_int().
        # The explicit bounds check is required so that negative ``ty`` raises
        # instead of wrapping around via negative indexing.
        codes = ("fp32", "fp16", "int8", "int4", "int2", "bf16")
        if 0 <= ty < len(codes):
            return SparseType(codes[ty])
        raise ValueError(f"Unsupported sparse type: {ty}")

    def as_int(self) -> int:
        """Return the stable integer code for this dtype (inverse of from_int)."""
        return {
            SparseType.FP32.value: 0,
            SparseType.FP16.value: 1,
            SparseType.INT8.value: 2,
            SparseType.INT4.value: 3,
            SparseType.INT2.value: 4,
            SparseType.BF16.value: 5,
        }[self.value]

    def bit_rate(self) -> int:
        """Return the number of bits used to store one element of this dtype."""
        return {
            SparseType.FP32.value: 32,
            SparseType.FP16.value: 16,
            SparseType.INT8.value: 8,
            SparseType.INT4.value: 4,
            SparseType.INT2.value: 2,
            SparseType.BF16.value: 16,
        }[self.value]

    def align_size(self) -> int:
        """Return the element-count alignment granularity for this dtype
        (how many elements pack into the common 4-byte alignment unit)."""
        return {
            SparseType.FP32.value: 1,
            SparseType.FP16.value: 2,
            SparseType.INT8.value: 4,
            SparseType.INT4.value: 8,
            SparseType.INT2.value: 16,
            SparseType.BF16.value: 2,
        }[self.value]

    def is_float(self) -> bool:
        """Return True for floating-point dtypes (fp32/fp16/bf16)."""
        # Direct membership test replaces the verbose if/else-return-bool form.
        return self.value in (
            SparseType.FP32.value,
            SparseType.FP16.value,
            SparseType.BF16.value,
        )


# Bytes per element for the whole-byte dtypes; sub-byte types (int4/int2)
# are intentionally omitted since their per-element size is fractional.
ELEMENT_SIZE: Dict[SparseType, int] = {
    SparseType.FP32: 4,
    SparseType.FP16: 2,
    SparseType.INT8: 1,
    SparseType.BF16: 2,
    # SparseType.INT4: 0.5,
}