in local_gemma/attention.py [0:0]
def load_hook(self, state_dict, prefix, *args):
if prefix + "q_proj.weight" in state_dict:
q_proj = state_dict.pop(prefix + "q_proj.weight")
k_proj = state_dict.pop(prefix + "k_proj.weight")
v_proj = state_dict.pop(prefix + "v_proj.weight")
state_dict[prefix + "qkv_proj.weight"] = torch.cat([q_proj, k_proj, v_proj])
if self.config.attention_bias:
q_bias = state_dict.pop(prefix + "q_proj.bias")
k_bias = state_dict.pop(prefix + "k_proj.bias")
v_bias = state_dict.pop(prefix + "v_proj.bias")
state_dict[prefix + "qkv_proj.bias"] = torch.cat([q_bias, k_bias, v_bias])