def get_weights()

in backends/gaudi/server/text_generation_server/layers/gptq/__init__.py [0:0]


    def get_weights(self, weights: Weights, prefix: str):
        """Load the quantized tensors stored under ``prefix`` as a ``GPTQWeight``.

        Args:
            weights: Object the tensors are read from via ``get_tensor``.
            prefix: Layer name prefix; ``.qweight``, ``.qzeros``, ``.scales``
                and (for GPTQ checkpoints) ``.g_idx`` are appended to it.

        Returns:
            A ``GPTQWeight`` built from the loaded tensors, or the result of
            ``DefaultWeightsLoader.get_weights`` when
            ``is_layer_skipped_quantization`` reports that this layer is
            excluded from quantization.

        Raises:
            RuntimeError: If reading ``{prefix}.qweight`` raises a
                ``RuntimeError``; the original error is attached as the cause.
        """
        self._get_gptq_params(weights)

        # Exllama is only enabled for 4-bit weights with desc_act turned off.
        use_exllama = self.bits == 4
        if self.desc_act:
            log_once(logger.warning, "Disabling exllama because desc_act=True")
            use_exllama = False

        # Layers listed as not converted are loaded through the default loader.
        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
            return DefaultWeightsLoader.get_weights(weights, prefix)

        try:
            qweight = weights.get_tensor(f"{prefix}.qweight")
        except RuntimeError as err:
            # Chain explicitly so the underlying loader failure stays visible
            # as the direct cause of the user-facing message.
            raise RuntimeError(
                "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
            ) from err

        # Only native GPTQ checkpoints are asked for a g_idx tensor.
        if self.quantize == "gptq" and self.quant_method == "gptq":
            g_idx = weights.get_tensor(f"{prefix}.g_idx")
        else:
            g_idx = None

        qzeros = weights.get_tensor(f"{prefix}.qzeros")
        scales = weights.get_tensor(f"{prefix}.scales")

        if use_exllama and g_idx is not None:
            # Shift the indices so that the first entry becomes zero.
            g_idx = g_idx - g_idx[0]

        if self.quantize == "gptq" and self.quant_method == "awq":
            log_once(
                logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
            )
            from text_generation_server.layers.awq.conversion_utils import (
                fast_awq_to_gptq,
            )

            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
            if use_exllama:
                g_idx = None
            else:
                # Synthesize sequential group indices: position i maps to
                # group i // groupsize.
                # NOTE(review): qweight.shape[0] * (32 // bits) is presumably
                # the unpacked row count (32-bit packing) - confirm.
                g_idx = (
                    torch.arange(
                        qweight.shape[0] * (32 // self.bits),
                        device=qweight.device,
                    )
                    // self.groupsize
                ).to(dtype=torch.int32)

        return GPTQWeight(
            qweight=qweight,
            qzeros=qzeros,
            scales=scales,
            g_idx=g_idx,
            bits=self.bits,
            groupsize=self.groupsize,
            use_exllama=use_exllama,
        )