in src/jobs/tune_base.py [0:0]
def __init__(self, learning_rate: float = 1e-4, batch_size: int = 2, model_name: str = 'google/flan-t5-base',
             label_column: str = "output", use_keywords: bool = True, single_tab_handling: bool = False,
             learning_rate_decay: bool = True, shrink_remove_encoder_layers: int = 0,
             shrink_remove_decoder_layers: int = 0,
             shrink_encoder_index_remove=None, shrink_decoder_index_remove=None, brevity_weight=None,
             label_prefix=None,
             shorten_training_label_boost=None,
             teacher_model_artifact=None,
             model_start_artifact=None,
             training_data_files=None,
             *,
             device: str = "cuda:0"):
    """Record the tuning configuration for this job.

    Nothing is loaded here: every argument is stored on ``self`` and
    ``self.model`` starts as ``None``; the rest of the class (not visible
    here) is expected to consume these attributes.

    Args:
        learning_rate: Stored on ``self.learning_rate``.
        batch_size: Stored on ``self.batch_size``.
        model_name: Stored on ``self.model_name``; presumably a model
            identifier such as the default ``'google/flan-t5-base'``.
        label_column: Stored on ``self.label_column``; presumably the name of
            the target column in the training data — confirm with callers.
        use_keywords: Chooses ``keyword_prompt`` (True) or
            ``document_prompt`` (False) as ``self.prompter``. Ignored for
            prompter selection when ``single_tab_handling`` is True, but
            still stored on ``self.use_keywords``.
        single_tab_handling: When True, ``self.prompter`` is always
            ``hybrid_prompt_gen``.
        learning_rate_decay: Stored on ``self.learning_rate_decay``.
        shrink_remove_encoder_layers: Stored as-is; presumably a count of
            encoder layers to drop — TODO confirm.
        shrink_remove_decoder_layers: Stored as-is; presumably a count of
            decoder layers to drop — TODO confirm.
        shrink_encoder_index_remove: Stored as-is (default ``None``).
        shrink_decoder_index_remove: Stored as-is (default ``None``).
        brevity_weight: Stored as-is (default ``None``).
        label_prefix: Stored as-is (default ``None``).
        shorten_training_label_boost: Stored as-is (default ``None``).
        teacher_model_artifact: Stored as-is (default ``None``).
        model_start_artifact: Stored as-is (default ``None``).
        training_data_files: Stored as-is (default ``None``).
        device: Device string stored on ``self.device``. Keyword-only;
            defaults to ``"cuda:0"`` (the value that used to be hard-coded),
            so existing callers see no change. Pass e.g. ``"cpu"`` or
            ``"cuda:1"`` to target a different device.
    """
    # Previously hard-coded to "cuda:0"; now configurable.
    self.device = device

    # Model / optimisation settings.
    self.model_name = model_name
    self.teacher_model_artifact = teacher_model_artifact
    self.model_start_artifact = model_start_artifact
    self.learning_rate = learning_rate
    self.learning_rate_decay = learning_rate_decay
    self.batch_size = batch_size
    self.model = None  # not loaded in the constructor

    # Data / label settings.
    self.label_column = label_column
    self.label_prefix = label_prefix
    self.brevity_weight = brevity_weight
    self.shorten_training_label_boost = shorten_training_label_boost
    self.training_data_files = training_data_files

    # Prompt construction: single-tab handling takes precedence over the
    # keyword/document choice.
    self.single_tab_handling = single_tab_handling
    self.use_keywords = use_keywords
    if single_tab_handling:
        self.prompter = hybrid_prompt_gen
    else:
        self.prompter = keyword_prompt if use_keywords else document_prompt

    # Layer-shrinking settings (stored verbatim).
    self.shrink_remove_encoder_layers = shrink_remove_encoder_layers
    self.shrink_remove_decoder_layers = shrink_remove_decoder_layers
    self.shrink_encoder_index_remove = shrink_encoder_index_remove
    self.shrink_decoder_index_remove = shrink_decoder_index_remove