def get_workers()

in pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py [0:0]


def _require_min_length(params, field, required, purpose):
    """Raise ValueError unless ``params.<field>`` has at least ``required`` entries.

    get_workers pairs several config lists up by position. Checking their
    lengths before any tokenizer or model is loaded replaces an opaque
    IndexError with an immediate, descriptive error. That IndexError could
    otherwise be raised after some models were already loaded.
    """
    actual = len(getattr(params, field))
    if actual < required:
        raise ValueError(f"params.{field} has {actual} entries but {required} are required ({purpose})")


def _load_worker_tokenizer(path, token, tokenizer_kwargs):
    """Load the tokenizer at ``path`` and apply model-family specific fix-ups.

    The substring checks are case-sensitive and independent (not an elif
    chain). A path matching several families receives every matching
    adjustment, in the order below.
    """
    tokenizer = AutoTokenizer.from_pretrained(path, token=token, trust_remote_code=False, **tokenizer_kwargs)
    if "oasst-sft-6-llama-30b" in path:
        tokenizer.bos_token_id = 1
        tokenizer.unk_token_id = 0
    if "guanaco" in path:
        tokenizer.eos_token_id = 2
        tokenizer.unk_token_id = 0
    # NOTE(review): case-sensitive match; a hub id spelled "Llama-2" (capital L)
    # will not match here. Confirm whether that is intended.
    if "llama-2" in path:
        tokenizer.pad_token = tokenizer.unk_token
        tokenizer.padding_side = "left"
    if "falcon" in path:
        tokenizer.padding_side = "left"
    if "Phi-3-mini-4k-instruct" in path:
        tokenizer.bos_token_id = 1
        tokenizer.eos_token_id = 32000
        tokenizer.unk_token_id = 0
        tokenizer.padding_side = "left"
    # Fall back to EOS for padding when no pad token is defined.
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def _build_worker_conv_template(name):
    """Return a freshly built conversation template for template ``name``.

    Raises:
        ValueError: if ``name`` is not a supported template name.
    """
    if name in ["llama-2", "mistral", "llama-3-8b", "vicuna"]:
        conv = get_conversation_template(name)
    elif name in ["phi-3-mini"]:
        conv = Conversation(
            name="phi-3-mini",
            system_template="<|system|>\n{system_message}",
            system_message="",
            roles=("<|user|>", "<|assistant|>"),
            sep_style=SeparatorStyle.CHATML,
            sep="<|end|>",
            stop_token_ids=[32000, 32001, 32007],
        )
    else:
        raise ValueError(f"Conversation template not recognized: {name!r}")

    # Per-template adjustments. The "zero_shot" branch looks unreachable
    # given the names accepted above; it is retained unchanged.
    if conv.name == "zero_shot":
        conv.roles = tuple(["### " + r for r in conv.roles])
        conv.sep = "\n"
    elif conv.name == "llama-2":
        conv.sep2 = conv.sep2.strip()
    return conv


def get_workers(params, eval=False):
    """Build tokenizers, conversation templates and ModelWorkers from ``params``.

    Args:
        params: Config object. Reads ``tokenizer_paths``, ``tokenizer_kwargs``,
            ``token``, ``conversation_templates``, ``model_paths``,
            ``model_kwargs``, ``devices`` and, optionally,
            ``num_train_models``. Entry ``i`` of each per-model list
            configures the i-th worker.
        eval: When False (default), ``start()`` is called on every worker.
            When True, workers are returned without being started.

    Returns:
        A ``(train_workers, test_workers)`` tuple. ``train_workers`` is the
        first ``num_train_models`` workers (all of them if the attribute is
        absent); ``test_workers`` is the remainder.

    Raises:
        ValueError: if a per-tokenizer or per-model list is too short, or a
            conversation template name is not recognized.
    """
    num_models = len(params.model_paths)

    # Validate all cheap-to-check inputs before downloading tokenizers or
    # loading models, so configuration mistakes fail fast.
    _require_min_length(params, "tokenizer_kwargs", len(params.tokenizer_paths), "one per tokenizer path")
    for field in ("model_kwargs", "tokenizer_paths", "conversation_templates", "devices"):
        _require_min_length(params, field, num_models, "one per model path")

    conv_templates = [_build_worker_conv_template(name) for name in params.conversation_templates]

    tokenizers = [
        _load_worker_tokenizer(path, params.token, kwargs)
        for path, kwargs in zip(params.tokenizer_paths, params.tokenizer_kwargs)
    ]
    print(f"Loaded {len(tokenizers)} tokenizers")
    print(f"Loaded {len(conv_templates)} conversation templates")

    # Every list was checked to be at least num_models long, so zip yields
    # exactly one tuple per model path.
    workers = [
        ModelWorker(model_path, params.token, model_kwargs, tokenizer, conv_template, device)
        for model_path, model_kwargs, tokenizer, conv_template, device in zip(
            params.model_paths, params.model_kwargs, tokenizers, conv_templates, params.devices
        )
    ]
    if not eval:
        for worker in workers:
            worker.start()

    num_train_models = getattr(params, "num_train_models", len(workers))
    print("Loaded {} train models".format(num_train_models))
    print("Loaded {} test models".format(len(workers) - num_train_models))

    return workers[:num_train_models], workers[num_train_models:]