in backends/gaudi/server/text_generation_server/utils/quantization.py [0:0]
def _get_quantizer_config(model_id, revision):
    """Collect quantization settings for ``model_id`` at ``revision``.

    Up to three files are consulted in order. Each one is tried only if the
    previous attempt raised:

    1. ``config.json`` -- its ``quantization_config`` section.
    2. ``quantize_config.json`` -- top-level ``bits``/``group_size``/``desc_act``.
    3. ``quant_config.json`` -- top-level ``w_bit``/``q_group_size``/``desc_act``.

    Any ``Exception`` raised inside a file's ``try`` body is caught. A failure
    never propagates; it only moves on to the next file, or after the last
    file, to the final ``return``.

    Locals assigned before the failing lookup are NOT reset. A partially read
    file still contributes the values it set, and later files overwrite only
    the values they assign themselves. Statement order inside each branch
    therefore decides which values survive a missing key.

    Args:
        model_id: Model identifier, forwarded to ``_get_config_json``.
        revision: Revision, forwarded to ``_get_config_json``.

    Returns:
        ``_FP8QuantizerConfig`` when ``config.json`` declares
        ``quantization_config.quant_method == "fbgemm_fp8"`` and provides
        ``activation_scale_ub``.

        Otherwise, a ``_QuantizerConfig`` built from whatever was read. Any
        value no file supplied keeps its default: ``bits=4``,
        ``groupsize=-1``, ``quant_method="gptq"``,
        ``checkpoint_format=None``, ``sym=False``, ``desc_act=False``,
        ``weight_block_size=None`` and ``modules_to_not_convert=[]``.
    """
    # Defaults, returned unchanged when no file can be read at all.
    bits = 4
    groupsize = -1
    quant_method = "gptq"
    checkpoint_format = None
    sym = False
    desc_act = False
    weight_block_size = None
    modules_to_not_convert = []
    filename = "config.json"
    try:
        # NOTE(review): _get_config_json is defined elsewhere. Presumably it
        # fetches and JSON-decodes `filename` from the model repo, raising if
        # the file is absent -- confirm.
        data = _get_config_json(model_id, revision, filename)
        # FP8 config. A missing "quantization_config" key raises here and
        # sends control to the quantize_config.json fallback.
        if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
            return _FP8QuantizerConfig(
                activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
            )
        # Read before the required "bits"/"group_size" lookups below, so it is
        # kept even when those keys are missing and the fallbacks run.
        weight_block_size = data["quantization_config"].get("weight_block_size", None)
        if "zero_point" in data["quantization_config"]:
            # zero_point == True means asymmetric quantization, so sym is its
            # inverse.
            sym = not data["quantization_config"]["zero_point"]
            # Overwritten by the file's own quant_method a few lines below,
            # once bits/group_size are read; it only survives if one of them
            # is missing.
            quant_method = "awq"
        elif "sym" in data["quantization_config"]:
            sym = data["quantization_config"]["sym"]
        # Required keys: a KeyError here falls through to the next file.
        bits = data["quantization_config"]["bits"]
        groupsize = data["quantization_config"]["group_size"]
        # Order is important here: quant_method is taken once the required
        # keys are present. desc_act is missing on some real models, so it
        # is read with .get() and defaults to False.
        quant_method = data["quantization_config"]["quant_method"]
        checkpoint_format = data["quantization_config"].get("checkpoint_format")
        desc_act = data["quantization_config"].get("desc_act", False)
        modules_to_not_convert = data["quantization_config"].get(
            "modules_to_not_convert", []
        )
        # The key may be present with a None value (e.g. JSON null);
        # normalize it to an empty list.
        if modules_to_not_convert is None:
            modules_to_not_convert = []
    except Exception:
        # config.json was unreadable or incomplete. Any locals it already set
        # are kept, and this file overwrites only what it assigns.
        filename = "quantize_config.json"
        try:
            data = _get_config_json(model_id, revision, filename)
            bits = data["bits"]
            groupsize = data["group_size"]
            if "zero_point" in data:
                sym = not data["zero_point"]
                quant_method = "awq"
            elif "sym" in data:
                sym = data["sym"]
            # Required in this format: a KeyError falls through to
            # quant_config.json, keeping bits/groupsize/sym set above.
            desc_act = data["desc_act"]
            if "version" in data and data["version"] == "GEMM":
                quant_method = "awq"
        except Exception:
            # Last format tried: w_bit / q_group_size key names.
            filename = "quant_config.json"
            try:
                data = _get_config_json(model_id, revision, filename)
                bits = data["w_bit"]
                groupsize = data["q_group_size"]
                desc_act = data["desc_act"]
                if "version" in data and data["version"] == "GEMM":
                    quant_method = "awq"
            except Exception:
                # Best effort: return defaults plus any partially read values.
                pass
    return _QuantizerConfig(
        bits=bits,
        groupsize=groupsize,
        quant_method=quant_method,
        checkpoint_format=checkpoint_format,
        sym=sym,
        desc_act=desc_act,
        weight_block_size=weight_block_size,
        modules_to_not_convert=modules_to_not_convert,
    )