# SetFit SOTA for Bio Text Classification

SetFit is a great practical tool for few-shot text classification, but did you know that you can also fine-tune a vanilla SetFit model for full-shot text classification and outperform models that were pre-trained on in-domain data?
Here we show such an example in the biological domain, where SetFit outperforms most of the models that were pre-trained on biological data, while being smaller than all of them.

The following table summarizes the results of different models on the HoC* dataset. All of the biological models were first pre-trained on in-domain biological data and then fine-tuned on the HoC training data from the BLUE benchmark. SetFit was not pre-trained on biological data: it starts from a general-purpose pre-trained sentence transformer (`sentence-transformers/paraphrase-mpnet-base-v2`, based on Microsoft's MPNet) and was fine-tuned only on the HoC training data. As the table shows, SetFit surpasses the BERT-based Bio models and matches the 85.1 F1 of the 347M-parameter BioGPT, the SOTA model for the Bio domain (https://analyticsindiamag.com/microsoft-launches-biogpt-the-chatgpt-of-lifescience/), while being roughly 3x smaller.

| **Model** | **#Params (M)** | **F1** | **Pre-train Data** |
|:-----------------------:|:-------:|:---------------:|:-----------------:|
| **BioBERT[2]** | 110 | 81.5 | Bio |
| **PubMedBERT[1]** | 340 | 82.7 | Bio |
| **BioLinkBERT[3]** | 340 | 84.9 | Bio |
| **GPT-2** | 355 | 81.8 | General |
| **BioGPT[4]** | 347 | **85.1** | Bio |
| **SetFit** | 105 | **85.1** | General |




References:

[1] Domain-specific language model pretraining for biomedical natural language processing. https://arxiv.org/abs/2007.15779

[2] BioBERT: a pre-trained biomedical language representation model for biomedical text mining. https://arxiv.org/abs/1901.08746

[3] LinkBERT: Pretraining Language Models with Document Links. https://arxiv.org/abs/2203.15827

[4] BioGPT: Generative Pre-trained Transformer for Biomedical Text Generation and Mining. https://arxiv.org/abs/2210.10341

[5] Automatic semantic classification of scientific literature according to the hallmarks of cancer. https://academic.oup.com/bioinformatics/article/32/3/432/1743783

[6] An evaluation of BERT and ELMo on ten benchmarking datasets. https://arxiv.org/abs/1906.05474


*HoC (the Hallmarks of Cancer corpus) consists of 1580 PubMed abstracts manually annotated by experts at the sentence level with the ten currently known hallmarks of cancer [5]. We follow the same train/test split as in [6].

### SetFit Multilabel HoC

In [None]:
!pip install setfit

Load the HoC dataset

In [None]:
!wget https://github.com/ncbi-nlp/BLUE_Benchmark/releases/download/0.1/data_v0.1.zip
!unzip data_v0.1.zip

import pandas as pd
import numpy as np

# Read train/test files
test_df = pd.read_csv('/content/data/hoc/test.tsv', sep='\t')
train_df = pd.read_csv('/content/data/hoc/train.tsv', sep='\t')
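The code above assumes each HoC TSV file has `index`, `sentence` and `labels` columns, with comma-separated hallmark names in `labels`. A minimal sketch on hypothetical in-memory rows (not real HoC data) shows how `pd.read_csv` turns an empty `labels` field into `NaN`, which is why the helpers below check `pd.isna` before splitting:

```python
import io

import pandas as pd

# Hypothetical rows mimicking the assumed HoC layout: index, sentence, labels.
# The second sentence has an empty labels field (no hallmark annotated).
toy_tsv = (
    "index\tsentence\tlabels\n"
    "1000_s0\tTumor cells invade nearby tissue.\tactivating invasion and metastasis\n"
    "1000_s1\tSamples were collected in 2010.\t\n"
)

toy_df = pd.read_csv(io.StringIO(toy_tsv), sep="\t")

# Empty fields are parsed as NaN, not as the empty string,
# so they must be handled before calling .split(",").
for _, row in toy_df.iterrows():
    labels = [] if pd.isna(row["labels"]) else row["labels"].split(",")
    print(row["index"], labels)
```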

In [None]:
LABELS = ['activating invasion and metastasis', 'avoiding immune destruction',
          'cellular energetics', 'enabling replicative immortality', 'evading growth suppressors',
          'genomic instability and mutation', 'inducing angiogenesis', 'resisting cell death',
          'sustaining proliferative signaling', 'tumor promoting inflammation']

In [None]:
# Convert comma-separated label strings to a multi-hot matrix (the same layout scikit-learn uses for multilabel targets)
def hotvec_multilabel(true_df):
    # One row per sentence, one column per hallmark in LABELS
    y_hotvec = np.zeros((len(true_df), len(LABELS)), dtype=int)

    for row_idx, labels in enumerate(true_df['labels']):
        # Sentences without any hallmark have NaN labels and stay all-zero
        if not pd.isna(labels):
            for label in labels.split(','):
                y_hotvec[row_idx, LABELS.index(label)] = 1

    return y_hotvec
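As the comment notes, this is the multi-hot layout scikit-learn uses for multilabel targets. A sketch of an equivalent conversion with scikit-learn's `MultiLabelBinarizer`, passing the label order explicitly via `classes` so that column indices line up with the label list; the three-label list and the toy rows are hypothetical stand-ins for `LABELS` and the HoC data:

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer

# Hypothetical shortened label list; the notebook uses the ten HoC hallmarks.
TOY_LABELS = ["cellular energetics", "inducing angiogenesis", "resisting cell death"]

toy_df = pd.DataFrame({
    "labels": ["inducing angiogenesis", np.nan, "cellular energetics,resisting cell death"],
})

# NaN (no hallmark) becomes an empty label list, i.e. an all-zero row.
label_lists = [[] if pd.isna(raw) else raw.split(",") for raw in toy_df["labels"]]

# Fixing `classes` keeps the column order identical to TOY_LABELS.
mlb = MultiLabelBinarizer(classes=TOY_LABELS)
y_multi_hot = mlb.fit_transform(label_lists)
print(y_multi_hot)
```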

### SetFit Multilabel

In [None]:
from datasets import Dataset
import evaluate
from setfit import SetFitModel, SetFitTrainer

model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2", 
    multi_target_strategy="multi-output",  # one classifier per label; alternatives: "one-vs-rest", "classifier-chain"
)

multilabel_f1_metric = evaluate.load("f1", "multilabel")
multilabel_accuracy_metric = evaluate.load("accuracy", "multilabel")

# Sentence-level micro-F1 and exact-match (subset) accuracy
def compute_metrics(y_pred, y_test):
    return {
        "f1": multilabel_f1_metric.compute(predictions=y_pred, references=y_test, average="micro")["f1"],
        "accuracy": multilabel_accuracy_metric.compute(predictions=y_pred, references=y_test)["accuracy"],
    }

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.
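The `evaluate` F1 and accuracy metrics used above wrap scikit-learn's `f1_score` and `accuracy_score`. Assuming that, a toy sketch with `sklearn.metrics` directly shows why subset (exact-match) accuracy can exceed micro-F1 when many sentences carry no hallmark at all: a correct all-zero prediction counts as an exact match but contributes nothing to micro-F1. The arrays below are hypothetical:

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

# Rows are sentences, columns are labels; three of four sentences have no label.
y_true = np.array([[0, 0], [0, 0], [0, 0], [1, 1]])
y_pred = np.array([[0, 0], [0, 0], [0, 0], [1, 0]])

# Exact-match accuracy: 3 of 4 rows are predicted perfectly.
subset_acc = accuracy_score(y_true, y_pred)
# Micro-F1 pools all label decisions: TP=1, FP=0, FN=1 -> P=1.0, R=0.5.
micro_f1 = f1_score(y_true, y_pred, average="micro")
print(subset_acc, round(micro_f1, 4))
```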


In [None]:
eval_dataset = Dataset.from_dict({"text": test_df['sentence'], "label": hotvec_multilabel(test_df)})
train_dataset = Dataset.from_dict({"text": train_df['sentence'], "label": hotvec_multilabel(train_df)})

In [None]:
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    metric=compute_metrics,
    num_iterations=5,
)
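`num_iterations` controls how many contrastive sentence pairs SetFit generates to fine-tune the sentence transformer body. Assuming the pre-1.0 SetFit behaviour of sampling one positive and one negative pair per training sentence per iteration (an assumption worth checking against the installed version), the counts in the training log can be reproduced with simple arithmetic:

```python
import math

num_iterations = 5       # as passed to SetFitTrainer above
pairs_per_sentence = 2   # assumption: one positive + one negative pair per iteration
batch_size = 16          # "Total train batch size" reported in the training log
num_pairs = 71200        # "Num examples" reported in the training log

# Under the assumption above, the log implies this many training sentences;
# compare with len(train_df).
implied_train_sentences = num_pairs // (num_iterations * pairs_per_sentence)

# One epoch over the generated pairs in mini-batches of batch_size.
optimization_steps = math.ceil(num_pairs / batch_size)
print(implied_train_sentences, optimization_steps)
```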

In [None]:
trainer.train()
metrics = trainer.evaluate()
print(metrics)

Applying column mapping to training dataset
***** Running training *****
  Num examples = 71200
  Num epochs = 1
  Total optimization steps = 4450
  Total train batch size = 16



STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
Applying column mapping to evaluation dataset
***** Running evaluation *****


{'f1': 0.7284569138276554, 'accuracy': 0.836671270718232}


### Evaluation of BLUE's HoC F1 (abstract level)

The support functions below are refactored from the BLUE benchmark evaluation code (https://github.com/ncbi-nlp/BLUE_Benchmark), which can be downloaded from https://github.com/ncbi-nlp/BLUE_Benchmark/releases/tag/0.1.

In [None]:
def divide(x, y):
    return np.true_divide(x, y, out=np.zeros_like(x, dtype=np.float64), where=y != 0)

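def get_p_r_f_arrary(test_predict_label, test_true_label):
    """Example-based metrics: precision and recall are computed per document and averaged,
    then F1 is the harmonic mean of the two averages. Returns (precision, recall, f1)."""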
    num, cat = test_predict_label.shape
    acc_list = []
    prc_list = []
    rec_list = []
    f_score_list = []
    for i in range(num):
        label_pred_set = set()
        label_gold_set = set()

        for j in range(cat):
            if test_predict_label[i, j] == 1:
                label_pred_set.add(j)
            if test_true_label[i, j] == 1:
                label_gold_set.add(j)

        uni_set = label_gold_set.union(label_pred_set)
        intersec_set = label_gold_set.intersection(label_pred_set)

        tt = len(intersec_set)
        if len(label_pred_set) == 0:
            prc = 0
        else:
            prc = tt / len(label_pred_set)

        acc = tt / len(uni_set)

        rec = tt / len(label_gold_set)

        if prc == 0 and rec == 0:
            f_score = 0
        else:
            f_score = 2 * prc * rec / (prc + rec)

        acc_list.append(acc)
        prc_list.append(prc)
        rec_list.append(rec)
        f_score_list.append(f_score)

    mean_prc = np.mean(prc_list)
    mean_rec = np.mean(rec_list)
    f_score = divide(2 * mean_prc * mean_rec, (mean_prc + mean_rec))
    return mean_prc, mean_rec, f_score

def eval_hoc(true_df, pred_df):
    """Aggregate sentence-level labels to abstracts and print BLUE's HoC precision, recall and F1."""
    data = {}

    assert len(true_df) == len(pred_df), \
        f'Gold line no {len(true_df)} vs Prediction line no {len(pred_df)}'

    for i in range(len(true_df)):
        true_row = true_df.iloc[i]
        pred_row = pred_df.iloc[i]
        assert true_row['index'] == pred_row['index'], \
            'Index does not match @{}: {} vs {}'.format(i, true_row['index'], pred_row['index'])

        key = true_row['index'][:true_row['index'].find('_')]
        if key not in data:
            data[key] = (set(), set())

        if not pd.isna(true_row['labels']):
            for l in true_row['labels'].split(','):
                data[key][0].add(LABELS.index(l))

        if not pd.isna(pred_row['labels']):
            for l in pred_row['labels'].split(','):
                data[key][1].add(LABELS.index(l))

    assert len(data) == 315, 'There are 315 documents in the test set: %d' % len(data)

    y_test = []
    y_pred = []
    for k, (true, pred) in data.items():
        t = [0] * len(LABELS)
        for i in true:
            t[i] = 1

        p = [0] * len(LABELS)
        for i in pred:
            p[i] = 1

        y_test.append(t)
        y_pred.append(p)

    y_test = np.array(y_test)
    y_pred = np.array(y_pred)

    # get_p_r_f_arrary returns (precision, recall, f1)
    p, r, f1 = get_p_r_f_arrary(y_pred, y_test)
    print('Precision: {:.1f}'.format(p*100))
    print('Recall   : {:.1f}'.format(r*100))
    print('F1       : {:.1f}'.format(f1*100))

#### Evaluate on test data

In [None]:
test_predict_label = trainer.model.predict(test_df['sentence'].tolist())

In [None]:
# Convert multi-hot predictions back to comma-separated label strings
num, cat = test_predict_label.shape
sentence_list = []
for i in range(num):
    sentence_set = set()
    for j in range(cat):
        if test_predict_label[i, j] == 1:
            sentence_set.add(LABELS[j])
    sentence_list.append(','.join(sentence_set))

# Reformat for HoC evaluation: sentences with no predicted or gold label become NaN
pred_df = test_df.assign(labels=sentence_list)
pred_df['labels'] = pred_df['labels'].replace({'': np.nan})
test_df['labels'] = test_df['labels'].replace({'': np.nan})

#### Evaluate F1 (abstract level)

In [None]:
eval_hoc(test_df, pred_df)

Precision: 83.8
Recall   : 86.4
F1       : 85.1
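`eval_hoc` scores at the abstract level: sentence ids share an abstract prefix before the first `_`, the hallmarks of all sentences in an abstract are unioned, and `get_p_r_f_arrary` then averages precision and recall over abstracts and takes the harmonic mean of those two averages. A self-contained sketch of the same procedure on hypothetical toy ids and labels (not HoC data) also shows that this F1 can differ from the mean of per-abstract F1 scores, which is what scikit-learn's `average="samples"` reports:

```python
import numpy as np
import pandas as pd

# Hypothetical 3-label setup; ids follow the "<abstract>_<sentence>" pattern used by eval_hoc.
TOY_LABELS = ["a", "b", "c"]
gold = pd.DataFrame({
    "index": ["100_s0", "100_s1", "200_s0", "200_s1"],
    "labels": ["a", "c", "b", np.nan],
})
pred = pd.DataFrame({
    "index": ["100_s0", "100_s1", "200_s0", "200_s1"],
    "labels": ["a", np.nan, "b", "c"],
})


def abstract_multi_hot(df, labels):
    """Union the sentence labels of each abstract into one multi-hot row."""
    rows = {}
    for sent_id, raw in zip(df["index"], df["labels"]):
        doc_id = sent_id.split("_", 1)[0]
        row = rows.setdefault(doc_id, np.zeros(len(labels), dtype=int))
        if not pd.isna(raw):
            for label in raw.split(","):
                row[labels.index(label)] = 1
    return np.array(list(rows.values()))


y_true = abstract_multi_hot(gold, TOY_LABELS)
y_pred = abstract_multi_hot(pred, TOY_LABELS)

hits = (y_true & y_pred).sum(axis=1)
# Assumes every abstract has at least one predicted and one gold label (true for this toy data).
precision = hits / y_pred.sum(axis=1)
recall = hits / y_true.sum(axis=1)

# BLUE-style: harmonic mean of the averaged precision and recall.
mean_p, mean_r = precision.mean(), recall.mean()
blue_f1 = 2 * mean_p * mean_r / (mean_p + mean_r)

# Alternative: average of per-abstract F1 scores.
per_doc_f1 = np.where(hits > 0, 2 * precision * recall / (precision + recall), 0.0)
print(mean_p, mean_r, round(blue_f1, 4), round(per_doc_f1.mean(), 4))
```

Here both abstracts have one correct label, but one under-predicts and the other over-predicts, so averaged precision and recall are both 0.75 while each abstract's own F1 is only about 0.67.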
