def get_embedding_layer()

in pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py [0:0]


def get_embedding_layer(model):
    """Return the input token-embedding module of ``model``.

    Known architectures are resolved through their explicit attribute paths.
    Any other model that exposes the Hugging Face ``get_input_embeddings()``
    method is handled through that method.

    Args:
        model: A causal language model instance.

    Returns:
        The model's input embedding layer, e.g. ``model.transformer.wte`` for
        GPT-J / GPT-2 or ``model.model.embed_tokens`` for Llama / Phi-3.

    Raises:
        ValueError: If the embedding layer cannot be located for ``model``.
    """
    if isinstance(model, (GPTJForCausalLM, GPT2LMHeadModel)):
        return model.transformer.wte
    if isinstance(model, (LlamaForCausalLM, Phi3ForCausalLM)):
        # Both architectures expose their token embeddings at the same path.
        return model.model.embed_tokens
    if isinstance(model, GPTNeoXForCausalLM):
        return model.base_model.embed_in

    # Fallback for other Hugging Face models. ``get_input_embeddings`` may be
    # missing (non-HF objects) or raise NotImplementedError (HF subclasses that
    # do not implement it); in either case, report the model as unsupported.
    get_input_embeddings = getattr(model, "get_input_embeddings", None)
    if callable(get_input_embeddings):
        try:
            layer = get_input_embeddings()
        except NotImplementedError:
            layer = None
        if layer is not None:
            return layer
    raise ValueError(f"Unknown model type: {type(model)}")