in pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py [0:0]
def get_embedding_layer(model):
    """Return the input token-embedding layer of ``model``.

    Architectures whose embedding attribute path is known are handled
    explicitly, exactly as before. Any other model falls back to the
    Hugging Face ``get_input_embeddings()`` accessor, so that additional
    architectures work without a dedicated branch.

    Args:
        model: A Hugging Face causal language model instance.

    Returns:
        The module used to embed input token ids.

    Raises:
        ValueError: If the model type is not one of the explicitly supported
            architectures and it does not expose a usable
            ``get_input_embeddings()`` method.
    """
    if isinstance(model, (GPTJForCausalLM, GPT2LMHeadModel)):
        return model.transformer.wte
    if isinstance(model, (LlamaForCausalLM, Phi3ForCausalLM)):
        return model.model.embed_tokens
    if isinstance(model, GPTNeoXForCausalLM):
        return model.base_model.embed_in

    # Generic fallback for architectures without a dedicated branch. The base
    # transformers implementation may raise NotImplementedError when a model
    # does not wire up its embeddings, so treat that as "unsupported".
    get_input_embeddings = getattr(model, "get_input_embeddings", None)
    if callable(get_input_embeddings):
        try:
            embeddings = get_input_embeddings()
        except NotImplementedError:
            embeddings = None
        if embeddings is not None:
            return embeddings

    raise ValueError(f"Unknown model type: {type(model)}")