in src/rime/models/zero_shot/bayes_lm.py [0:0]
def __init__(self, item_df, max_num_candidates=None, batch_size=100,
             prompt="a user will watch {y} after watching {x}",
             item_pop_power=1, item_pop_pseudo=0.01, temperature=1, gamma=0,
             candidate_selection_method=None, model_name='gpt2',  # bert-base-uncased
             text_column_name='TITLE'):
    """Set up a zero-shot Bayesian language-model scorer over an item catalog.

    Parameters
    ----------
    item_df : pandas.DataFrame
        Item catalog. Must contain ``text_column_name`` and a ``_hist_len``
        column (historical interaction count used as a popularity prior).
    max_num_candidates : int, optional
        Candidates scored per user. Defaults to 2 with a warning
        (testing only) when not provided.
    batch_size : int
        Mini-batch size used when scoring with the LM.
    prompt : str
        Template with ``{x}`` (history item text) and ``{y}`` (candidate text).
    item_pop_power, item_pop_pseudo : float
        Popularity prior: ``log_p_y = power * log(_hist_len + pseudo)``.
    temperature, gamma : float
        Score-calibration knobs stored for downstream use.
    candidate_selection_method : str, optional
        ``'greedy'`` or ``'sample'``; inferred from ``item_pop_power`` if None.
    model_name : str
        HuggingFace checkpoint id; names starting with ``'bert'`` load a
        masked LM, everything else a causal LM.
    text_column_name : str
        Column of ``item_df`` holding the text substituted into the prompt.
    """
    assert text_column_name in item_df, f"require {text_column_name} as data(y)"
    self.item_df = item_df.copy()
    # popularity log-prior; pseudo-count keeps log() finite for cold items
    self.item_df['log_p_y'] = item_pop_power * np.log(item_df['_hist_len'] + item_pop_pseudo)
    if max_num_candidates is None:
        warnings.warn("please set max_num_candidates, default=2 only for testing purposes")
        max_num_candidates = 2
    self.max_num_candidates = max_num_candidates
    self.batch_size = batch_size
    self.prompt = prompt
    self.temperature = temperature
    self.gamma = gamma
    self.text_column_name = text_column_name
    if candidate_selection_method is None:
        # popularity-weighted prior -> deterministic top-k; flat prior -> sampling
        candidate_selection_method = 'greedy' if item_pop_power > 0 else 'sample'
    self.candidate_selection_method = candidate_selection_method
    # huggingface model initialization
    self.tokenizer = AutoTokenizer.from_pretrained(model_name)
    assert self.tokenizer.padding_side == 'right', "expect right padding"
    if model_name.startswith('bert'):
        self.model = BertForMaskedLM.from_pretrained(model_name)
    else:  # causal LM (gpt2 family and anything else non-bert)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        # BUG FIX: the original set pad_token only for the exact name 'gpt2',
        # yet this branch loads ANY non-bert checkpoint (e.g. 'gpt2-medium'),
        # which would then have pad_token=None and fail on right-padded
        # batching. Reuse EOS as PAD whenever the tokenizer lacks a pad token;
        # outcome is identical for the default 'gpt2'.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
    self.model.eval()  # inference only; disables dropout etc.
    self.loss = torch.nn.CrossEntropyLoss(reduction='none')  # per-token NLL