in src/rime/models/zero_shot/item_knn.py [0:0]
def __init__(self, item_df, batch_size=100,
             item_pop_power=1, item_pop_pseudo=0.01,
             model_name='bert-base-uncased',  # gpt2
             pooling=None,  # cls or mean
             temperature=None, gamma=0.5):
    """Set up item representations and popularity biases for zero-shot KNN.

    Item vectors come from a precomputed ``embedding`` column of ``item_df``
    when present; otherwise the ``TITLE`` column is encoded with a pretrained
    Hugging Face model via ``self._compute_embeddings``.

    Args:
        item_df: DataFrame indexed by item. Must contain a ``_hist_len``
            column and at least one of ``TITLE`` or ``embedding``.
        batch_size: stored as ``self.batch_size``.
        item_pop_power: multiplier on the log-popularity item bias.
        item_pop_pseudo: pseudo-count added to ``_hist_len`` before the log
            so that (for a positive value) zero-history items get a finite
            bias.
        model_name: Hugging Face model name. Names starting with ``bert``
            are loaded with ``BertForMaskedLM``; all others with
            ``AutoModelForCausalLM``.
        pooling: ``'cls'`` or ``'mean'``. Defaults to ``'cls'`` for BERT
            models and ``'mean'`` for everything else.
        temperature: stored as ``self.temperature``. Defaults are only
            defined for ``'bert-base-uncased'`` (10) and ``'gpt2'`` (100).
        gamma: stored as ``self.gamma``.

    Raises:
        ValueError: if ``item_df`` has neither ``TITLE`` nor ``embedding``,
            if ``temperature`` is None and ``model_name`` has no default,
            or if the loaded tokenizer does not pad on the right.
    """
    if "TITLE" not in item_df and "embedding" not in item_df:
        raise ValueError("require TITLE or embedding")
    self.item_index = item_df.index
    # Log-popularity bias; the pseudo-count keeps log() finite when an
    # item's history length is zero.
    self.item_biases = item_pop_power * np.log(
        item_df['_hist_len'].values + item_pop_pseudo)
    self.batch_size = batch_size

    # Decide the model family once so the pooling default and the model
    # class chosen below can never disagree.
    is_bert = model_name.startswith('bert')

    if temperature is None:
        default_temperatures = {
            'bert-base-uncased': 10,
            'gpt2': 100,
        }
        try:
            temperature = default_temperatures[model_name]
        except KeyError:
            raise ValueError(
                f"no default temperature for model_name={model_name!r}; "
                "pass temperature explicitly") from None
    if pooling is None:
        # CLS pooling only makes sense for a bidirectional encoder; in a
        # causal LM the first token sees no later context, so use mean.
        pooling = 'cls' if is_bert else 'mean'
    self.pooling = pooling
    self.temperature = temperature
    self.gamma = gamma

    if "embedding" in item_df:
        self.item_embeddings = np.vstack(item_df["embedding"].values)
    else:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.padding_side != 'right':
            raise ValueError("expect right padding")
        # GPT-2-style tokenizers ship without a pad token; reuse EOS so
        # batched padding works for any such model, not just 'gpt2'.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        if is_bert:
            self.model = BertForMaskedLM.from_pretrained(model_name)
        else:  # causal LM, e.g. gpt2
            self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.model.eval()  # eval mode
        self.item_embeddings = self._compute_embeddings(
            item_df["TITLE"].values)