def __init__()

in modules/SwissArmyTransformer/sat/model/finetune/lora2.py

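This constructor builds a LoRA replacement for an existing (possibly tensor-parallel) linear layer: it re-creates the base layer with matching class, dtype and bias (or as a 4-bit NF4 layer under QLoRA), copies the original weights over, and allocates one pair of low-rank (A, B) matrices per output partition.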

    def __init__(self, original_cls, partition, in_dim, out_dim, r, lora_alpha=1., lora_dropout=0., qlora=False, original_obj=None):
        """Wrap an existing linear layer (`original_obj`, an instance of
        `original_cls`) with rank-`r` LoRA matrices. `partition` describes how
        the output dimension is split, either as a slice count (int) or as a
        list of relative slice sizes; `qlora` builds the base layer as a 4-bit
        NF4 linear instead."""
        super().__init__()
        assert original_obj is not None, "The original linear object must be given!"
        # Dropout applies to the LoRA branch only; fall back to a no-op when disabled.
        if lora_dropout and lora_dropout > 0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = nn.Identity()
        self.r = r
        self.lora_alpha = lora_alpha
        # Standard LoRA scaling: the low-rank update is multiplied by alpha / r.
        self.scaling = self.lora_alpha / self.r
        bias = original_obj.bias is not None
        dtype = original_obj.weight.dtype
        if qlora:
            try:
                self.original = HackLinearNF4(in_dim, out_dim, bias=bias)
            except Exception as exc:
                raise Exception(
                    'Building the 4-bit layer failed. You need an up-to-date bitsandbytes: try '
                    '`pip install -U bitsandbytes`. If the error persists after installation, run '
                    '`from bitsandbytes.nn import LinearNF4` in Python and fix the reported error.'
                ) from exc
        else:
            # Look up the replacement class and its template kwargs; copy the dict so
            # the shared template in map_cls is not mutated across calls.
            base_cls, kwargs = map_cls[original_cls]
            kwargs = dict(kwargs)
            if original_cls is ColumnParallelLinear:
                kwargs['stride'] = partition
                kwargs['skip_init'] = True
                kwargs['params_dtype'] = dtype
            elif original_cls is RowParallelLinear:
                kwargs['final_bias'] = original_obj.final_bias
                kwargs['skip_init'] = True
                kwargs['params_dtype'] = dtype
            else:
                kwargs['dtype'] = dtype
            self.original = base_cls(in_dim, out_dim, **kwargs, bias=bias)
        # Copy the weights (and bias, if present) over from the layer being replaced.
        self.original.weight.data.copy_(original_obj.weight.data.detach().clone())
        if bias:
            self.original.bias.data.copy_(original_obj.bias.data.detach().clone())
        if isinstance(partition, int):
            # `partition` equal slices of the output dimension: each slice gets its own
            # (A, B) pair. A spans the full input dim; B covers out_dim // partition rows.
            self.matrix_A = HackParameterList([nn.Parameter(torch.empty((r, original_obj.weight.shape[1]), dtype=dtype)) for _ in range(partition)])
            self.matrix_B = HackParameterList([nn.Parameter(torch.empty((original_obj.weight.shape[0] // partition, r), dtype=dtype)) for _ in range(partition)])
            for i in range(partition):
                # Standard LoRA init: A random (Kaiming uniform), B zero, so the initial update is zero.
                nn.init.kaiming_uniform_(self.matrix_A[i], a=math.sqrt(5))
                nn.init.zeros_(self.matrix_B[i])
                # Mark B so it is sharded correctly under tensor model parallelism.
                self.matrix_B[i].model_parallel = True
                self.matrix_B[i].tensor_model_parallel = True
        else:
            # `partition` is a list of relative slice sizes; turn each weight into an
            # absolute number of output rows.
            unit = original_obj.weight.shape[0] // sum(partition)
            new_sizes = [unit * p for p in partition]
            self.matrix_A = HackParameterList([nn.Parameter(torch.empty((r, original_obj.weight.shape[1]), dtype=dtype)) for _ in partition])
            self.matrix_B = HackParameterList([nn.Parameter(torch.empty((sz, r), dtype=dtype)) for sz in new_sizes])
            for i in range(len(partition)):
                nn.init.kaiming_uniform_(self.matrix_A[i], a=math.sqrt(5))
                nn.init.zeros_(self.matrix_B[i])
                self.matrix_B[i].model_parallel = True
                self.matrix_B[i].tensor_model_parallel = True
        self.partition = partition
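
For orientation, here is a minimal single-device sketch of the structure this constructor sets up, with the sat-specific classes (HackLinearNF4, HackParameterList, the parallel linears) replaced by plain torch.nn equivalents. TinyLoraLinear and its forward pass are hypothetical illustrations of the int-partition case, not code from lora2.py:

import math
import torch
import torch.nn as nn

class TinyLoraLinear(nn.Module):
    # Single-device analogue of the constructor above (int-partition case).
    def __init__(self, in_dim, out_dim, r, partition=1, lora_alpha=1., lora_dropout=0.):
        super().__init__()
        self.original = nn.Linear(in_dim, out_dim)  # stands in for the sat base class
        self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0 else nn.Identity()
        self.scaling = lora_alpha / r
        # One (A, B) pair per output slice, as in the int-partition branch.
        self.matrix_A = nn.ParameterList(
            [nn.Parameter(torch.empty(r, in_dim)) for _ in range(partition)])
        self.matrix_B = nn.ParameterList(
            [nn.Parameter(torch.empty(out_dim // partition, r)) for _ in range(partition)])
        for A, B in zip(self.matrix_A, self.matrix_B):
            nn.init.kaiming_uniform_(A, a=math.sqrt(5))  # A: random
            nn.init.zeros_(B)                            # B: zero, so the initial update is 0

    def forward(self, x):
        h = self.lora_dropout(x)
        # Each partition contributes its own slice of the output dimension.
        delta = torch.cat([h @ A.T @ B.T for A, B in zip(self.matrix_A, self.matrix_B)], dim=-1)
        return self.original(x) + delta * self.scaling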
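
Because every B starts at zero, the wrapped layer initially reproduces the base layer exactly; the learned update can later be folded back into the dense weight (standard LoRA practice, not part of this constructor). Continuing with the sketch above:

layer = TinyLoraLinear(16, 32, r=4, partition=2)
x = torch.randn(3, 16)
assert torch.allclose(layer(x), layer.original(x))  # zero-initialized B => no change yet

with torch.no_grad():  # fold the low-rank update into the base weight
    delta_w = torch.cat([B @ A for A, B in zip(layer.matrix_A, layer.matrix_B)], dim=0)
    layer.original.weight += layer.scaling * delta_w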