modules/SwissArmyTransformer/sat/model/finetune/prompt_tuning.py
# -*- encoding: utf-8 -*-
'''
@File : prompt_tuning.py
@Time : 2021/12/12 20:45:18
@Author : Ming Ding
@Contact : dm18@mails.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import math
import random
import torch

from sat.transformer_defaults import attention_fn_default
from sat.model.base_model import BaseModel, BaseMixin, non_conflict

class PrefixTuningMixin(BaseMixin):
    def __init__(self, num_layers, hidden_size_per_attention_head, num_attention_heads, prefix_len):
        super().__init__()
        # One learnable (key, value) prefix per transformer layer, shaped
        # (2, num_heads, prefix_len, head_dim); the 0.01 scale keeps the
        # random init small so the prefixes start out nearly neutral.
        self.prefix = torch.nn.ParameterList([
            torch.nn.Parameter(torch.randn(2, num_attention_heads, prefix_len, hidden_size_per_attention_head) * 0.01)
            for _ in range(num_layers)
        ])
        self.prefix_len = prefix_len

    @non_conflict
    def attention_fn(self, q, k, v, mask, dropout_fn, old_impl=attention_fn_default, **kw_args):
        # Fetch this layer's learned prefix and split it into key/value halves.
        prefix_k, prefix_v = self.prefix[kw_args['layer_id']]

        # Broadcast the prefix over the batch and append it to the projected
        # keys and values, so every query also attends to prefix_len extra
        # learned positions.
        b, nh, seq_len, hidden_size = k.shape
        prefix_k = prefix_k.unsqueeze(0).expand(b, nh, -1, hidden_size)
        prefix_v = prefix_v.unsqueeze(0).expand(b, nh, -1, hidden_size)
        k = torch.cat((k, prefix_k), dim=2)
        v = torch.cat((v, prefix_v), dim=2)

        # A non-scalar mask has to be widened so the prefix positions are
        # always attendable (mask value 1); a scalar mask already covers them.
        if mask.numel() > 1:
            mask_prefixed = torch.ones(self.prefix_len, device=mask.device, dtype=mask.dtype)
            mask_prefixed = mask_prefixed.expand(*(mask.size()[:-1]), -1)
            mask = torch.cat((mask, mask_prefixed), dim=-1)
        return old_impl(q, k, v, mask, dropout_fn, **kw_args)

# P-Tuning v2 applies the same deep, per-layer key/value prefixes, so it is
# exposed as an alias of the same mixin.
PTuningV2Mixin = PrefixTuningMixin
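
# ---------------------------------------------------------------------------
# Usage sketch (not from the original file): one plausible way to attach the
# mixin to an already-constructed sat BaseModel and train only the prefixes.
# `model`, `args`, and the mixin name 'prefix-tuning' are assumptions made
# for illustration; the exact setup depends on the surrounding training
# script. Kept as comments so the module's import-time behaviour is unchanged.
#
#     model.add_mixin('prefix-tuning', PrefixTuningMixin(
#         num_layers=args.num_layers,
#         hidden_size_per_attention_head=args.hidden_size // args.num_attention_heads,
#         num_attention_heads=args.num_attention_heads,
#         prefix_len=16,
#     ))
#     # Freeze everything except the prefix parameters, assuming mixin
#     # parameters surface under the registered mixin name.
#     for name, param in model.named_parameters():
#         param.requires_grad = 'prefix-tuning' in name
# ---------------------------------------------------------------------------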