in modules/SwissArmyTransformer/sat/model/finetune/lora2.py [0:0]
def reinit(self, parent_model):
    # Swap the self-attention projections of every layer in layer_range for LoRA-wrapped versions.
    for i in self.layer_range:
        print_rank0(f'replacing layer {i} attention with lora')
        # Output projection: a single partition.
        parent_model.transformer.layers[i].attention.dense = replace_linear_with_lora(
            parent_model.transformer.layers[i].attention.dense, 1, self.r, self.lora_alpha, self.lora_dropout,
            qlora=self.qlora, in_size=parent_model.transformer.hidden_size, out_size=None)
        # Fused QKV projection: partitioned by `stride`; with multi-query attention the K/V heads are
        # narrower than the Q heads, so the output size has to be spelled out explicitly.
        parent_model.transformer.layers[i].attention.query_key_value = replace_linear_with_lora(
            parent_model.transformer.layers[i].attention.query_key_value,
            parent_model.transformer.layers[i].attention.stride,
            self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora,
            in_size=parent_model.transformer.hidden_size,
            out_size=None if not parent_model.transformer.num_multi_query_heads
            else parent_model.transformer.layers[i].attention.inner_hidden_size
                + parent_model.transformer.layers[i].attention.hidden_size_per_attention_head
                * parent_model.transformer.layers[i].attention.num_multi_query_heads * 2)
        # Decoder layers additionally carry cross-attention; wrap those projections as well.
        if self.cross_attention and parent_model.transformer.layers[i].is_decoder:
            print_rank0(f'replacing layer {i} cross attention with lora')
            # Width of the fused K/V projection; narrower when multi-query cross-attention is used.
            kv_size = (parent_model.transformer.layers[i].cross_attention.inner_hidden_size * 2
                       if not parent_model.transformer.cross_num_multi_query_heads
                       else parent_model.transformer.layers[i].cross_attention.hidden_size_per_attention_head
                           * parent_model.transformer.layers[i].cross_attention.cross_num_multi_query_heads * 2)
            parent_model.transformer.layers[i].cross_attention.dense = replace_linear_with_lora(
                parent_model.transformer.layers[i].cross_attention.dense, 1, self.r, self.lora_alpha, self.lora_dropout,
                qlora=self.qlora, in_size=parent_model.transformer.layers[i].cross_attention.inner_hidden_size,
                out_size=parent_model.transformer.hidden_size)
            parent_model.transformer.layers[i].cross_attention.query = replace_linear_with_lora(
                parent_model.transformer.layers[i].cross_attention.query, 1, self.r, self.lora_alpha, self.lora_dropout,
                qlora=self.qlora, in_size=parent_model.transformer.hidden_size,
                out_size=parent_model.transformer.layers[i].cross_attention.inner_hidden_size)
            parent_model.transformer.layers[i].cross_attention.key_value = replace_linear_with_lora(
                parent_model.transformer.layers[i].cross_attention.key_value, 2, self.r, self.lora_alpha, self.lora_dropout,
                qlora=self.qlora, in_size=parent_model.transformer.layers[i].cross_attention.cross_attn_hidden_size,
                out_size=kv_size)
    if self.qlora:
        # For QLoRA, additionally replace every linear layer under the transformer with a 4-bit NF4 version.
        print_rank0('replacing chatglm linear layer with 4bit')
        def replace_linear_with_nf4(model, name=None, cache={}):
            # Leaf case: rebuild the linear layer (including the tensor-parallel variants) as NF4
            # and copy over its weights and bias.
            if type(model) in (nn.Linear, RowParallelLinear, ColumnParallelLinear):
                out_dim, in_dim = model.weight.shape
                bias = model.bias is not None
                new_linear = HackLinearNF4(in_dim, out_dim, bias=bias)
                new_linear.weight.data.copy_(model.weight.data.detach().clone())
                if bias:
                    new_linear.bias.data.copy_(model.bias.data.detach().clone())
                return new_linear
            # Otherwise recurse into the children, caching the replacement for each child module
            # so that a module shared by several parents is only converted once.
            names = set()
            for name, child in model.named_children():
                if name not in names:
                    if child in cache:
                        new_child = cache[child]
                    else:
                        new_child = replace_linear_with_nf4(child, name=name, cache=cache)
                        cache[child] = new_child
                    setattr(model, name, new_child)
                    names.add(name)
            # named_children() yields a shared child only under its first attribute name, so alias
            # names surface one at a time as their targets get replaced; keep re-scanning and patch
            # them from the cache until nothing is left to update.
            flag = True
            while flag:
                flag = False
                for name, child in model.named_children():
                    if name not in names:
                        setattr(model, name, cache[child])
                        names.add(name)
                        flag = True
            return model
        replace_linear_with_nf4(parent_model.transformer, None, {})
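
For orientation, below is a minimal sketch of the kind of module a helper like replace_linear_with_lora is assumed to return: the pretrained projection is frozen and a trainable low-rank update, scaled by lora_alpha / r, is added on top. The class and argument names are illustrative only and ignore the partition stride, the multi-query output sizing, and the QLoRA/NF4 base weights handled in the real code.

import torch
import torch.nn as nn

class LoraLinearSketch(nn.Module):
    """Hypothetical stand-in for a LoRA-wrapped linear layer (not the sat implementation)."""
    def __init__(self, base: nn.Linear, r: int, lora_alpha: int, lora_dropout: float = 0.0):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)           # freeze the pretrained weight
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)
        self.lora_A = nn.Linear(base.in_features, r, bias=False)   # down-projection to rank r
        self.lora_B = nn.Linear(r, base.out_features, bias=False)  # up-projection back to out_features
        nn.init.zeros_(self.lora_B.weight)               # B starts at zero, so the wrapper is initially a no-op
        self.dropout = nn.Dropout(lora_dropout)
        self.scaling = lora_alpha / r

    def forward(self, x):
        # y = W x + (alpha / r) * B(A(dropout(x)))
        return self.base(x) + self.lora_B(self.lora_A(self.dropout(x))) * self.scaling

# Usage: wrap a dense projection the way reinit wraps attention.dense, then train only the LoRA factors.
dense = nn.Linear(1024, 1024)
wrapped = LoraLinearSketch(dense, r=8, lora_alpha=16, lora_dropout=0.1)
out = wrapped(torch.randn(2, 16, 1024))                  # shape: (2, 16, 1024)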