# in modules/SwissArmyTransformer/sat/model/finetune/lora2.py
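# Context assumed from the rest of this file: torch, torch.nn as nn, math,
# the sat ColumnParallelLinear / RowParallelLinear classes, and the
# file-local HackLinearNF4, HackParameterList, and map_cls helpers.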
def __init__(self, original_cls, partition, in_dim, out_dim, r, lora_alpha=1., lora_dropout=0., qlora=False, original_obj=None):
    super().__init__()
    assert original_obj is not None, "original linear object must be given!"
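    # Dropout is applied only to the input of the LoRA branch; the frozen
    # base layer always sees the raw input.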
    if lora_dropout and lora_dropout > 0:
        self.lora_dropout = nn.Dropout(p=lora_dropout)
    else:
        self.lora_dropout = lambda x: x
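    # Standard LoRA scaling: the low-rank update is multiplied by alpha / r.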
    self.r = r
    self.lora_alpha = lora_alpha
    self.scaling = self.lora_alpha / self.r
    bias = original_obj.bias is not None
    dtype = original_obj.weight.dtype
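    # With qlora, the frozen base layer is rebuilt as a 4-bit NF4 linear
    # (bitsandbytes); otherwise it is rebuilt as the matching sat linear class.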
    if qlora:
        try:
            self.original = HackLinearNF4(in_dim, out_dim, bias=bias)
        except Exception:
            raise Exception('Build 4bit layer failed. You need to install the latest bitsandbytes. Try `pip install bitsandbytes`. If you still see an error after installation, try running `from bitsandbytes.nn import LinearNF4` in Python and fix that error first.')
    else:
        base_cls, kwargs = map_cls[original_cls]
        if original_cls is ColumnParallelLinear:
            kwargs['stride'] = partition
            kwargs['skip_init'] = True
            kwargs['params_dtype'] = dtype
        elif original_cls is RowParallelLinear:
            kwargs['final_bias'] = original_obj.final_bias
            kwargs['skip_init'] = True
            kwargs['params_dtype'] = dtype
        else:
            kwargs['dtype'] = dtype
        self.original = base_cls(in_dim, out_dim, **kwargs, bias=bias)
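    # Copy the pretrained weights (and bias, if any) into the rebuilt layer.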
    self.original.weight.data.copy_(original_obj.weight.data.detach().clone())
    if bias:
        self.original.bias.data.copy_(original_obj.bias.data.detach().clone())
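    # An int partition means the output dim splits into `partition` equal
    # slices (e.g. a fused QKV projection), each with its own (A, B) pair.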
    if type(partition) is int:
        self.matrix_A = HackParameterList([nn.Parameter(torch.empty((r, original_obj.weight.shape[1]), dtype=dtype)) for _ in range(partition)])
        self.matrix_B = HackParameterList([nn.Parameter(torch.empty((original_obj.weight.shape[0] // partition, r), dtype=dtype)) for _ in range(partition)])
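        # Standard LoRA init: Kaiming-uniform A and zero B, so the initial
        # low-rank update is exactly zero.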
        for i in range(partition):
            nn.init.kaiming_uniform_(self.matrix_A[i], a=math.sqrt(5))
            nn.init.zeros_(self.matrix_B[i])
            self.matrix_B[i].model_parallel = True
            self.matrix_B[i].tensor_model_parallel = True
    else:
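        # A list partition gives relative slice sizes for an unevenly split
        # output dim; each B_i covers its proportional share of the rows.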
        new_sizes = [original_obj.weight.shape[0] // sum(partition) * i for i in partition]
        self.matrix_A = HackParameterList([nn.Parameter(torch.empty((r, original_obj.weight.shape[1]), dtype=dtype)) for _ in partition])
        self.matrix_B = HackParameterList([nn.Parameter(torch.empty((sz, r), dtype=dtype)) for sz in new_sizes])
        for i in range(len(partition)):
            nn.init.kaiming_uniform_(self.matrix_A[i], a=math.sqrt(5))
            nn.init.zeros_(self.matrix_B[i])
            self.matrix_B[i].model_parallel = True
            self.matrix_B[i].tensor_model_parallel = True
    self.partition = partition
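
# A minimal sketch (not part of the original file) of how the parameters built
# above are typically consumed in a LoRA forward pass: the frozen base output
# plus the scaled low-rank updates, concatenated per partition along the last
# dim. For simplicity it assumes `self.partition` is an int; the real forward
# in lora2.py also has to handle the list-partition case.
def forward(self, x):
    # Frozen (possibly 4-bit quantized) base projection.
    base_out = self.original(x)
    lora_outputs = []
    for i in range(self.partition):
        # (alpha / r) * dropout(x) @ A_i^T @ B_i^T for each output slice.
        lora_outputs.append(
            (self.lora_dropout(x) @ self.matrix_A[i].T @ self.matrix_B[i].T) * self.scaling
        )
    return base_out + torch.cat(lora_outputs, -1)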