chatlearn/models/megatron/lora/initializer.py (26 lines of code) (raw):

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""distributed initializer."""

import math
import warnings

import torch
from torch import Tensor
from torch.nn.init import calculate_gain


def distributed_kaiming_uniform_(
    tensor: Tensor, fan: int, a: float = 0, nonlinearity: str = 'leaky_relu'
):
    r"""Fill ``tensor`` in place with Kaiming (He) uniform values for a given fan.

    Follows the method described in `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al. (2015),
    using a uniform distribution. The resulting tensor will have values sampled
    from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as He initialization.

    The fan is taken from the ``fan`` argument and is never derived from
    ``tensor.shape``. NOTE(review): presumably this is so that a shard of a
    partitioned weight can be initialized with the fan of the full weight --
    confirm against the callers.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        fan: fan_in or fan_out for W. Must be a positive ``int``.
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'``
            (default).

    Returns:
        ``tensor`` itself after the in-place fill. A tensor with a zero-sized
        dimension is returned untouched (a warning is emitted). When
        ``tensor`` overrides ``__torch_function__``, whatever that override
        returns.

    Raises:
        AssertionError: if ``fan`` is not an ``int`` or is not greater than 0.
            The check is skipped for zero-element tensors and for tensors
            handled through ``__torch_function__``, since both paths return
            before it.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> distributed_kaiming_uniform_(w, fan=3, nonlinearity='relu')
    """
    # Hand the call to tensor-like objects that override __torch_function__.
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(
            distributed_kaiming_uniform_,
            (tensor, fan),
            tensor=tensor,
            fan=fan,
            a=a,
            nonlinearity=nonlinearity)

    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor

    # Raised explicitly instead of using an `assert` statement: asserts are
    # removed under `python -O`, which would let fan == 0 or a negative fan
    # reach math.sqrt below. Exception type and message are kept as before.
    if not (isinstance(fan, int) and fan > 0):
        raise AssertionError("Expect fan int and greater than 0.")

    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    # Calculate uniform bounds from standard deviation
    bound = math.sqrt(3.0) * std
    # In-place fill must not be recorded by autograd.
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)