def setup()

in bench/generation/setup/awq.py


from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer


def setup(model_id: str, weights: str, activations: str, group_size: int = 64, version: str = "GEMV_FAST"):
    if activations != "none":
        raise ValueError("Activation quantization is not supported by HQQ")
    if weights != "int4":
        raise ValueError("AWQ only supports int4 weights.")
    quant_config = {"zero_point": True, "q_group_size": group_size, "w_bit": 4, "version": version}
    # Load the full-precision model and its tokenizer
    model = AutoAWQForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
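    # Pad with the EOS token on the left so batched generation works for causal models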
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"
    # Quantize
    model.quantize(tokenizer, quant_config=quant_config)
    # AWQ models must be saved and reloaded before they can be used for generation
    quant_path = model_id.replace("/", "-") + f"_{group_size}_{version}"
    model.save_quantized(quant_path)
    # Reload model
    model = AutoAWQForCausalLM.from_quantized(quant_path)
    # Hack: force transformers 4.36.2 behaviour by patching in prepare_inputs_for_generation (provided elsewhere in this module)
    model.model.prepare_inputs_for_generation = prepare_inputs_for_generation
    # Hack: expose a device attribute, since AWQ wrappers are not transformers models
    model.device = next(model.parameters()).device
    return model, tokenizer
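
A minimal usage sketch, not part of the benchmark file: the model id, prompt, and generation settings below are illustrative assumptions, shown only to clarify how the returned model and tokenizer are meant to be used together.

model, tokenizer = setup("facebook/opt-350m", weights="int4", activations="none", group_size=64)
# The patched model.device attribute tells us where to place the inputs
inputs = tokenizer(["Hello, my name is"], return_tensors="pt", padding=True).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))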