# get_max_num_lbl_tok()
#
# from scripts/adapet/ADAPET/setfit_adapet.py [0:0]


def get_max_num_lbl_tok(task_name, train_ds, pretrained_weight, lang_star_dict=None):
    """Return the maximum number of tokens needed to encode any label text.

    Args:
        task_name: Task identifier; checked against ``AMZ_MULTI_LING`` to
            decide which label set to use.
        train_ds: Mapping-style dataset exposing a ``'label_text'`` column.
        pretrained_weight: Model name or path passed to
            ``AutoTokenizer.from_pretrained``.
        lang_star_dict: Optional mapping whose *values* are the label strings
            to use for the multilingual Amazon tasks. Only consulted when
            ``task_name`` is in ``AMZ_MULTI_LING`` and the mapping is truthy.

    Returns:
        int: Token count of the longest label (0 if the label set is empty).
    """
    tokenizer = AutoTokenizer.from_pretrained(pretrained_weight)

    # Multilingual Amazon tasks draw their labels from the language/star map;
    # every other task uses the training split's label_text column.
    if task_name in AMZ_MULTI_LING and lang_star_dict:
        labels = set(lang_star_dict.values())
    else:
        labels = set(train_ds['label_text'])

    # Encode without special tokens so len() is the label length itself.
    # The previous `len(encoded) - 2` hard-coded the assumption of exactly
    # one CLS and one SEP token, which miscounts for tokenizers that add a
    # different number of special tokens (e.g. GPT-2 adds none).
    return max(
        (len(tokenizer.encode(lab, add_special_tokens=False)) for lab in labels),
        default=0,  # empty label set -> 0, matching the old loop's behavior
    )