in pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py [0:0]
def get_embeddings(model, input_ids):
    """Embed ``input_ids`` with the input-embedding layer of ``model``.

    The embedding module lives under a different attribute for each supported
    architecture, so this function selects it by checking the model's class.

    Args:
        model: A supported Hugging Face causal-LM instance (GPT-J, GPT-2,
            Llama, GPT-NeoX, Mixtral, Mistral or Phi-3).
        input_ids: Token ids passed straight to the embedding module
            (presumably an integer tensor of token indices -- confirm against
            callers).

    Returns:
        The output of the model's embedding module. For GPT-J, GPT-2 and
        GPT-NeoX the result is additionally cast with ``.half()`` (float16).
        For the other architectures it is returned in the embedding layer's
        own dtype.

    Raises:
        ValueError: If ``model`` is not one of the supported classes.
    """
    # NOTE(review): only some branches cast to float16. Presumably those
    # models are expected to run in half precision. Confirm this is intended
    # before loading them in another dtype.
    if isinstance(model, (GPTJForCausalLM, GPT2LMHeadModel)):
        token_embedding = model.transformer.wte
        return token_embedding(input_ids).half()

    if isinstance(model, LlamaForCausalLM):
        token_embedding = model.model.embed_tokens
        return token_embedding(input_ids)

    if isinstance(model, GPTNeoXForCausalLM):
        token_embedding = model.base_model.embed_in
        return token_embedding(input_ids).half()

    # Mixtral, Mistral and Phi-3 share the Llama-style ``model.embed_tokens``
    # layout.
    if isinstance(model, (MixtralForCausalLM, MistralForCausalLM, Phi3ForCausalLM)):
        token_embedding = model.model.embed_tokens
        return token_embedding(input_ids)

    raise ValueError(f"Unknown model type: {type(model)}")