def __init__()

in src/rime/models/zero_shot/item_knn.py [0:0]


    def __init__(self, item_df, batch_size=100,
                 item_pop_power=1, item_pop_pseudo=0.01,
                 model_name='bert-base-uncased',  # gpt2
                 pooling=None,  # cls or mean
                 temperature=None, gamma=0.5):
        """Set up item biases and item embeddings from ``item_df``.

        Embeddings are taken from the ``embedding`` column when it exists;
        otherwise a pretrained language model is loaded and the ``TITLE``
        column is passed to ``self._compute_embeddings``.

        Args:
            item_df: DataFrame indexed by item. Must have a ``_hist_len``
                column and at least one of ``TITLE`` or ``embedding``.
            batch_size: Stored on the instance as ``self.batch_size``.
            item_pop_power: Multiplier applied to the log-popularity bias.
            item_pop_pseudo: Pseudo-count added to ``_hist_len`` before the
                logarithm is taken.
            model_name: Pretrained checkpoint name. Names starting with
                ``bert`` are loaded with ``BertForMaskedLM``; everything else
                is loaded with ``AutoModelForCausalLM``.
            pooling: ``'cls'`` or ``'mean'``. When ``None`` it defaults to
                ``'cls'`` if ``'bert'`` occurs in ``model_name``, else
                ``'mean'``.
            temperature: When ``None`` it defaults to 10 for
                ``bert-base-uncased`` and 100 for ``gpt2``.
            gamma: Stored on the instance as ``self.gamma``.

        Raises:
            AssertionError: If ``item_df`` has neither ``TITLE`` nor
                ``embedding``, or if the loaded tokenizer does not pad on the
                right.
            KeyError: If ``temperature`` is ``None`` and ``model_name`` has no
                built-in default temperature.
        """
        assert "TITLE" in item_df or "embedding" in item_df, "require TITLE or embedding"

        self.item_index = item_df.index
        # Popularity bias: item_pop_power * log(history length + pseudo-count).
        self.item_biases = item_pop_power * np.log(item_df['_hist_len'].values + item_pop_pseudo)
        self.batch_size = batch_size
        if temperature is None:
            # Only these two checkpoints have a built-in default; any other
            # model_name must be given an explicit temperature.
            temperature = {
                'bert-base-uncased': 10,
                'gpt2': 100,
            }[model_name]
        if pooling is None:
            pooling = 'cls' if 'bert' in model_name else 'mean'
        self.pooling = pooling
        self.temperature = temperature
        self.gamma = gamma

        if "embedding" in item_df:
            # Precomputed embeddings: stack the per-item vectors row-wise.
            # No tokenizer or model is loaded on this path.
            self.item_embeddings = np.vstack(item_df["embedding"].values)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            assert self.tokenizer.padding_side == 'right', "expect right padding"
            if self.tokenizer.pad_token is None:
                # Some tokenizers (gpt2 and its variants) have no pad token.
                # Reuse EOS so padded batches can be built for any such
                # checkpoint, not only the one literally named 'gpt2'.
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # NOTE(review): the pooling default tests `'bert' in model_name`
            # while this branch tests `startswith('bert')`; the two disagree
            # for names such as 'distilbert-...'. Confirm which is intended.
            if model_name.startswith('bert'):
                self.model = BertForMaskedLM.from_pretrained(model_name)
            else:  # gpt2
                self.model = AutoModelForCausalLM.from_pretrained(model_name)
            self.model.eval()  # eval mode
            self.item_embeddings = self._compute_embeddings(item_df["TITLE"].values)