def model_predict()

in mico/utils/utils.py [0:0]


import torch  # needed by the pooling logic below


def model_predict(model, text, tokenizer, max_length, pooling_strategy, selected_layer_idx, device=None):
    """This function generates BERT representation.

    Parameters
    ----------
    model : BertModel
        This is the BERT model.
    text : list of string
        These are the input sentences. Because the code calls `list(text)`,
        passing a bare string would split it into individual characters;
        wrap a single sentence in a list instead.
    tokenizer : `BertTokenizerFast`
        This is the corresponding tokenizer for the BERT model.
    max_length : int
        This is the maximum length of the tokenized sequence.
        If a sentence is too long, the part beyond this length is truncated.
        Note that the tokenized sequence is usually longer than the number of words in the sentence.
    pooling_strategy : string
        This controls how the token outputs are pooled.
        If `pooling_strategy='CLS_TOKEN'`, only the output of the first [CLS] token is used.
        If `pooling_strategy='REDUCE_MEAN'`, the average output of all non-padding tokens is used.
    selected_layer_idx : int
        This selects the layer whose output is collected.
        Index -1 means the last layer. The BERT base model has 12 layers.
    device : int (GPU index) or string ('cpu' or 'cuda')
        This is the device the BERT model is on; the encoded inputs are
        moved there before the forward pass.
    
    Returns
    -------
    BERT representation : `torch.Tensor`
        A tensor of shape (batch_size, 768) for the BERT base model.
    """
    encoded_input = tokenizer(list(text), return_tensors='pt', padding=True, truncation=True, max_length=max_length)
    # assign_gpu is a helper defined elsewhere in mico/utils/utils.py that
    # moves the encoded input tensors onto the given device.
    encoded_input = assign_gpu(encoded_input, device=device)
    outputs = model(**encoded_input)
    # outputs[2] holds the tuple of per-layer hidden states; this requires
    # the model to be loaded with output_hidden_states=True.
    hidden_states_layer = outputs[2][selected_layer_idx]
    if pooling_strategy == 'CLS_TOKEN':
        # Use only the hidden state of the first ([CLS]) token.
        return hidden_states_layer[:, 0, :]
    elif pooling_strategy == 'REDUCE_MEAN':
        # Average the hidden states of the non-padding tokens: the einsum sums
        # the token vectors selected by the attention mask, and the division
        # normalizes each row by its token count (see the equivalence sketch
        # after this function).
        unmasked_indices = (encoded_input['attention_mask'] == 1).float()
        return torch.mul(1 / torch.sum(unmasked_indices, dim=1).unsqueeze(1),
                         torch.einsum('ijk, ij -> ik', hidden_states_layer, unmasked_indices))
    else:
        error_message = "--pooling_strategy=%s is not supported. Please use 'CLS_TOKEN' or 'REDUCE_MEAN'." \
                        % pooling_strategy
        raise NotImplementedError(error_message)
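
For reference, the `REDUCE_MEAN` branch is simply a masked mean written with
`torch.einsum`. A minimal standalone sketch (not from the repo; the shapes and
mask values below are made up) checking the equivalence:

import torch

hidden = torch.randn(2, 5, 768)                 # (batch, tokens, dim)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]]).float()  # attention mask

# einsum formulation used in model_predict
pooled = torch.mul(1 / torch.sum(mask, dim=1).unsqueeze(1),
                   torch.einsum('ijk, ij -> ik', hidden, mask))

# explicit masked mean over non-padding tokens
reference = (hidden * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True)

assert torch.allclose(pooled, reference, atol=1e-6)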
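
A minimal usage sketch, assuming the Hugging Face `transformers` package; the
checkpoint name `bert-base-uncased` and the input sentences are illustrative,
not taken from the repo. The model must be loaded with
`output_hidden_states=True` so that `outputs[2]` exists:

import torch
from transformers import BertModel, BertTokenizerFast

from mico.utils.utils import model_predict  # assumed import path, per the header above

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)
model.eval()

with torch.no_grad():
    embeddings = model_predict(model,
                               ["The first sentence.", "A second, longer sentence."],
                               tokenizer,
                               max_length=128,
                               pooling_strategy='REDUCE_MEAN',
                               selected_layer_idx=-1,
                               device='cpu')
print(embeddings.shape)  # torch.Size([2, 768]) for the BERT base model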