def __init__()

in src/rime/models/zero_shot/bayes_lm.py [0:0]


    def __init__(self, item_df, max_num_candidates=None, batch_size=100,
                 prompt="a user will watch {y} after watching {x}",
                 item_pop_power=1, item_pop_pseudo=0.01, temperature=1, gamma=0,
                 candidate_selection_method=None, model_name='gpt2',  # bert-base-uncased
                 text_column_name='TITLE'):
        """Configure the zero-shot Bayes-LM scorer and load its language model.

        Keeps a copy of ``item_df`` augmented with a log-popularity prior
        column ``log_p_y``, records the scoring hyper-parameters, and sets up
        a huggingface tokenizer/model pair (causal LM by default, masked LM
        for bert-* model names).
        """
        assert text_column_name in item_df, f"require {text_column_name} as data(y)"

        # Popularity prior: log_p_y = power * log(hist_len + pseudo count).
        # NOTE(review): assumes item_df carries a '_hist_len' column — confirm upstream.
        self.item_df = item_df.copy()
        self.item_df['log_p_y'] = item_pop_power * np.log(item_df['_hist_len'] + item_pop_pseudo)

        if max_num_candidates is None:
            warnings.warn("please set max_num_candidates, default=2 only for testing purposes")
            max_num_candidates = 2
        self.max_num_candidates = max_num_candidates

        self.batch_size = batch_size
        self.prompt = prompt
        self.temperature = temperature
        self.gamma = gamma
        self.text_column_name = text_column_name

        # Default strategy depends on whether the popularity prior is active.
        self.candidate_selection_method = (
            candidate_selection_method
            if candidate_selection_method is not None
            else ('greedy' if item_pop_power > 0 else 'sample')
        )

        # --- huggingface tokenizer/model initialization ---
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        assert self.tokenizer.padding_side == 'right', "expect right padding"
        if model_name == 'gpt2':
            # GPT-2 ships without a pad token; reuse EOS so batching works.
            # NOTE(review): exact-match means 'gpt2-medium' etc. get no pad token — confirm intended.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        model_cls = BertForMaskedLM if model_name.startswith('bert') else AutoModelForCausalLM
        self.model = model_cls.from_pretrained(model_name)
        self.model.eval()  # inference only: disable dropout / train-mode layers

        # Per-token losses (reduction='none') so callers can aggregate themselves.
        self.loss = torch.nn.CrossEntropyLoss(reduction='none')