def load_awq_quantized()

in fastchat/modules/awq.py


# Module-level imports needed by this function (AWQConfig and find_awq_ckpt are
# defined elsewhere in fastchat/modules/awq.py).
import sys

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, modeling_utils


def load_awq_quantized(model_name, awq_config: AWQConfig, device):
    print("Loading AWQ quantized model...")

    # tinychat (from the AWQ project) provides the quantized weight loader and the
    # fused modules used below; fail fast with install instructions if it is missing.
    try:
        from tinychat.utils import load_quant
        from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
    except ImportError as e:
        print(f"Error: Failed to import tinychat. {e}")
        print("Please double-check that you have successfully installed AWQ.")
        print("See https://github.com/lm-sys/FastChat/blob/main/docs/awq.md")
        sys.exit(-1)

    # Load the config and tokenizer from the original (unquantized) model.
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, use_fast=False, trust_remote_code=True
    )

    # Disable weight initialization: the real weights come from the AWQ checkpoint,
    # so randomly initializing them first would only waste time.
    def skip(*args, **kwargs):
        pass

    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.kaiming_normal_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip
    modeling_utils._init_weights = False

    # Build an empty fp16 model skeleton from the config (no pretrained weights loaded).
    torch.set_default_dtype(torch.half)
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

    # Resolve the AWQ checkpoint once and pick the loading path based on the model family.
    ckpt = find_awq_ckpt(awq_config)
    if any(name in ckpt for name in ["llama", "vicuna"]):
        # LLaMA/Vicuna models take the fast path with fused attention, norm, and MLP modules.
        model = load_quant.load_awq_llama_fast(
            model, ckpt, awq_config.wbits, awq_config.groupsize, device
        )
        make_quant_attn(model, device)
        make_quant_norm(model)
        make_fused_mlp(model)
    else:
        # Generic loader for other architectures.
        model = load_quant.load_awq_model(
            model, ckpt, awq_config.wbits, awq_config.groupsize, device
        )
    return model, tokenizer
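
A minimal usage sketch (not from the repo), assuming AWQConfig exposes ckpt, wbits, and groupsize fields as defined elsewhere in fastchat/modules/awq.py; the model name, checkpoint path, and device string below are placeholders:

from fastchat.modules.awq import AWQConfig, load_awq_quantized

# Hypothetical 4-bit, group-size-128 AWQ checkpoint for a Vicuna model.
awq_config = AWQConfig(
    ckpt="/path/to/vicuna-7b-awq-w4-g128.pt",
    wbits=4,
    groupsize=128,
)
model, tokenizer = load_awq_quantized("lmsys/vicuna-7b-v1.5", awq_config, "cuda")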