in src/peft/tuners/multitask_prompt_tuning/model.py [0:0]
def __init__(self, config: MultitaskPromptTuningConfig, word_embeddings):
    """Create the per-task prefix parameters and optionally seed them from a saved state dict.

    After delegating to the parent constructor, this stores the task / rank /
    token sizes from ``config`` and registers two parameters drawn from a
    normal distribution (mean 0, std 0.02):

    * ``prefix_task_cols`` with shape
      ``(num_tasks, num_virtual_tokens * num_transformer_submodules, num_ranks)``
    * ``prefix_task_rows`` with shape ``(num_tasks, num_ranks, token_dim)``

    When ``config.prompt_tuning_init`` is one of ``AVERAGE_SOURCE_TASKS``,
    ``EXACT_SOURCE_TASK`` or ``ONLY_SOURCE_SHARED``, a state dict is read from
    ``config.prompt_tuning_init_state_dict_path`` and loaded into this module.

    Args:
        config: Configuration supplying ``num_tasks``, ``num_ranks``,
            ``num_virtual_tokens``, ``num_transformer_submodules``,
            ``task_type``, ``token_dim`` and the ``prompt_tuning_init*`` fields.
        word_embeddings: Forwarded to the parent constructor; the device of its
            ``weight`` is used as ``map_location`` for non-safetensors files.

    Raises:
        ValueError: If the init method needs a source state dict but
            ``config.prompt_tuning_init_state_dict_path`` is ``None``.
    """
    super().__init__(config, word_embeddings)

    self.num_tasks = config.num_tasks
    self.num_ranks = config.num_ranks
    self.num_virtual_tokens = config.num_virtual_tokens

    # Fall back to 2 submodules for seq2seq tasks and 1 otherwise when the
    # config leaves the value unset.
    submodule_count = config.num_transformer_submodules
    if submodule_count is None:
        if config.task_type == TaskType.SEQ_2_SEQ_LM:
            submodule_count = 2
        else:
            submodule_count = 1
    self.num_transformer_submodules = submodule_count

    self.token_dim = config.token_dim

    prompt_length = self.num_virtual_tokens * self.num_transformer_submodules
    cols_shape = (self.num_tasks, prompt_length, self.num_ranks)
    rows_shape = (self.num_tasks, self.num_ranks, self.token_dim)

    # Columns are created before rows so the registration order and the order
    # of random draws stay fixed.
    self.prefix_task_cols = torch.nn.Parameter(torch.normal(mean=0, std=0.02, size=cols_shape))
    self.prefix_task_rows = torch.nn.Parameter(torch.normal(mean=0, std=0.02, size=rows_shape))

    init_mode = config.prompt_tuning_init
    per_task_modes = (
        MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS,
        MultitaskPromptTuningInit.EXACT_SOURCE_TASK,
    )
    source_modes = per_task_modes + (MultitaskPromptTuningInit.ONLY_SOURCE_SHARED,)

    # Every other init method leaves the freshly sampled parameters untouched.
    if init_mode not in source_modes:
        return

    source_path = config.prompt_tuning_init_state_dict_path
    if source_path is None:
        raise ValueError(
            f"prompt_tuning_init_state_dict_path needs to be specified with {config.prompt_tuning_init} "
            "init method"
        )

    if source_path.endswith(".safetensors"):
        from safetensors.torch import load_file

        # Note: no device argument is passed on this path.
        source_state = load_file(source_path)
    else:
        source_state = torch_load(
            source_path,
            map_location=word_embeddings.weight.device,
        )

    if init_mode in per_task_modes:
        source_cols = source_state["prefix_task_cols"]
        source_rows = source_state["prefix_task_rows"]

        # NOTE(review): both reductions below leave a leading dimension of 1,
        # so the strict load presumably expects num_tasks == 1 - confirm.
        if init_mode == MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS:
            # Average over the first (task) dimension, keeping it as size 1.
            source_cols = source_cols.mean(0, keepdim=True)
            source_rows = source_rows.mean(0, keepdim=True)
        elif init_mode == MultitaskPromptTuningInit.EXACT_SOURCE_TASK:
            # Pick one entry along the first dimension and restore that axis.
            task_index = config.prompt_tuning_init_task
            source_cols = source_cols[task_index, ...].unsqueeze(0)
            source_rows = source_rows[task_index, ...].unsqueeze(0)

        remapped_state = {
            "embedding.weight": source_state["prompt_embeddings"],
            "prefix_task_cols": source_cols,
            "prefix_task_rows": source_rows,
        }
        self.load_state_dict(remapped_state, strict=True)
    elif init_mode == MultitaskPromptTuningInit.ONLY_SOURCE_SHARED:
        # Only the embedding weight is supplied; strict=False tolerates the
        # absent prefix_task_* keys, which keep their random initialisation.
        remapped_state = {
            "embedding.weight": source_state["prompt_embeddings"],
        }
        self.load_state_dict(remapped_state, strict=False)