in pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py [0:0]
def get_embedding_matrix(model):
    """Return the input token-embedding weight of a supported causal language model.

    The embedding module lives at a different attribute path depending on the
    architecture, so this dispatches on the concrete model class.

    Args:
        model: An instance of one of the supported architectures: GPT-J, GPT-2,
            Llama, GPT-NeoX, Mixtral, Mistral, or Phi-3.

    Returns:
        The ``weight`` attribute of the model's input embedding module.

    Raises:
        ValueError: If ``model`` is not one of the supported architectures.
    """
    if isinstance(model, (GPTJForCausalLM, GPT2LMHeadModel)):
        return model.transformer.wte.weight
    # These architectures all expose their embedding layer at ``model.model.embed_tokens``.
    if isinstance(
        model,
        (LlamaForCausalLM, MixtralForCausalLM, MistralForCausalLM, Phi3ForCausalLM),
    ):
        return model.model.embed_tokens.weight
    if isinstance(model, GPTNeoXForCausalLM):
        return model.base_model.embed_in.weight
    raise ValueError(f"Unknown model type: {type(model)}")