def get_embeddings()

in pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py [0:0]


def get_embeddings(model, input_ids):
    """Embed ``input_ids`` with the token-embedding layer of a supported causal LM.

    The embedding module is located per architecture, and the type checks run
    in a fixed order:

    * GPT-J / GPT-2: ``model.transformer.wte``, with the result cast via ``.half()``.
    * LLaMA: ``model.model.embed_tokens``.
    * GPT-NeoX: ``model.base_model.embed_in``, with the result cast via ``.half()``.
    * Mixtral / Mistral / Phi-3: ``model.model.embed_tokens``.

    Args:
        model: A Hugging Face causal-LM instance of one of the classes above.
            ``isinstance`` is used, so subclasses also match.
        input_ids: Token ids passed straight to the embedding module
            (presumably an integer tensor -- confirm against callers).

    Returns:
        Whatever the selected embedding module returns for ``input_ids``, with
        ``.half()`` applied on the GPT-J/GPT-2 and GPT-NeoX paths only.

    Raises:
        ValueError: If ``model`` is none of the supported classes.
    """
    # NOTE(review): only the GPT-J/GPT-2 and GPT-NeoX paths force ``.half()``;
    # the embed_tokens-based models keep the module's own dtype. Confirm this
    # asymmetry is intentional before relying on a particular output dtype.
    if isinstance(model, (GPTJForCausalLM, GPT2LMHeadModel)):
        return model.transformer.wte(input_ids).half()
    if isinstance(model, LlamaForCausalLM):
        return model.model.embed_tokens(input_ids)
    if isinstance(model, GPTNeoXForCausalLM):
        return model.base_model.embed_in(input_ids).half()
    if isinstance(model, (MixtralForCausalLM, MistralForCausalLM, Phi3ForCausalLM)):
        return model.model.embed_tokens(input_ids)
    raise ValueError(f"Unknown model type: {type(model)}")