def fuse_attention_weights()

in local_gemma/modeling_local_gemma_2.py


def fuse_attention_weights(model: LocalGemma2ForCausalLM, device, torch_dtype) -> LocalGemma2ForCausalLM:
    for idx, layer in tqdm(enumerate(model.model.layers), desc="Fusing attention weights", total=model.config.num_hidden_layers):
        # snapshot the un-fused q/k/v/o projection weights before swapping the module
        state_dict = layer.self_attn.state_dict()
        del layer.self_attn
        layer.self_attn = Gemma2FusedAttention(model.config, layer_idx=idx)
        # convert un-fused to fused weights via the module's load_state_dict pre-hook
        layer.self_attn.load_state_dict(state_dict)
        layer.self_attn.to(device, dtype=torch_dtype)
    return model
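
For context, a minimal sketch of what the weight fusion amounts to, independent of this module: the separate q/k/v projections of an attention block can be replaced by a single wider linear layer whose weight is the row-wise concatenation of the three original weights, so one matmul produces all three activations. The sizes and module names below (q_proj, k_proj, v_proj, qkv_proj) are illustrative assumptions, not the Gemma2FusedAttention implementation.

import torch
import torch.nn as nn

hidden_size, head_dim, num_heads, num_kv_heads = 64, 16, 4, 2

# un-fused projections, as found in a stock attention module (names assumed)
q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)

# fused projection: one matmul instead of three
qkv_proj = nn.Linear(hidden_size, (num_heads + 2 * num_kv_heads) * head_dim, bias=False)
with torch.no_grad():
    # nn.Linear weights are (out_features, in_features), so stack along dim 0
    qkv_proj.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))

# the fused output splits back into q, k, v and matches the un-fused projections
x = torch.randn(1, 8, hidden_size)
q, k, v = qkv_proj(x).split(
    [num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim], dim=-1
)
assert torch.allclose(q, q_proj(x), atol=1e-6)
assert torch.allclose(k, k_proj(x), atol=1e-6)
assert torch.allclose(v, v_proj(x), atol=1e-6)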