in local_gemma/modeling_local_gemma_2.py [0:0]
def fuse_attention_weights(model: LocalGemma2ForCausalLM, device, torch_dtype) -> LocalGemma2ForCausalLM:
    for idx, layer in tqdm(
        enumerate(model.model.layers),
        desc="Fusing attention weights",
        total=model.config.num_hidden_layers,
    ):
        # snapshot the un-fused attention weights before swapping the module out
        state_dict = layer.self_attn.state_dict()
        del layer.self_attn
        layer.self_attn = Gemma2FusedAttention(model.config, layer_idx=idx)
        # the pre-registered load_state_dict hook converts the un-fused weights to the fused layout
        layer.self_attn.load_state_dict(state_dict)
        layer.self_attn.to(device, dtype=torch_dtype)
    return model
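
A minimal usage sketch, assuming LocalGemma2ForCausalLM.from_pretrained follows the standard transformers loading API and a CUDA device is available; the checkpoint name below is only an illustrative example:

import torch

model = LocalGemma2ForCausalLM.from_pretrained("google/gemma-2-9b-it", torch_dtype=torch.float16)
model = fuse_attention_weights(model, device="cuda", torch_dtype=torch.float16)
# The returned model is used exactly like the original; only the per-layer
# self_attn modules have been replaced with Gemma2FusedAttention instances.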