maga_transformer/utils/smooth_quant_convert/llama/smoothquant.py:

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
Utilities for SmoothQuant models
'''

import copy
import functools
from collections import defaultdict

import torch
import torch.nn as nn
from tqdm import tqdm
from transformers.pytorch_utils import Conv1D


@torch.no_grad()
def apply_smoothing(scales,
                    gemm_weights,
                    layernorm_weights=None,
                    layernorm_bias=None,
                    dtype=torch.float32,
                    layernorm_1p=False):
    """Fold smoothing scales into the LayerNorm and GEMM weights.

    All updates are in place, so every tensor keeps its original dtype;
    ``dtype`` is accepted for interface compatibility with the callers.
    """
    if not isinstance(gemm_weights, list):
        gemm_weights = [gemm_weights]

    if layernorm_weights is not None:
        assert layernorm_weights.numel() == scales.numel()
        layernorm_weights.div_(scales)
    if layernorm_bias is not None:
        assert layernorm_bias.numel() == scales.numel()
        layernorm_bias.div_(scales)
    if layernorm_1p:
        # "1p" LayerNorms store weight - 1, so (w + 1) / s - 1 = w / s + 1 / s - 1.
        layernorm_weights += (1 / scales) - 1

    for gemm in gemm_weights:
        # Weights are stored transposed ([out, in]); scale each input channel.
        gemm.mul_(scales.view(1, -1))


@torch.no_grad()
def smooth_gemm(gemm_weights,
                act_scales,
                layernorm_weights=None,
                layernorm_bias=None,
                alpha=0.5,
                weight_scales=None):
    """Compute per-channel SmoothQuant scales for a GEMM and apply them in place."""
    if not isinstance(gemm_weights, list):
        gemm_weights = [gemm_weights]
    orig_dtype = gemm_weights[0].dtype

    for gemm in gemm_weights:
        # gemm_weights are expected to be transposed, i.e. [out, in].
        assert gemm.shape[1] == act_scales.numel()

    if weight_scales is None:
        # Per-input-channel max |w| across all GEMMs sharing this input.
        weight_scales = torch.cat(
            [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights],
            dim=0)
        weight_scales = weight_scales.max(dim=0)[0]
    weight_scales = weight_scales.to(float).clamp(min=1e-5)

    # SmoothQuant migration: s_j = max|x_j|^alpha / max|w_j|^(1 - alpha).
    scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) /
              weight_scales.pow(1 - alpha)).clamp(min=1e-5)

    apply_smoothing(scales, gemm_weights, layernorm_weights, layernorm_bias,
                    orig_dtype)

    return scales


@torch.no_grad()
def smooth_gemm_fc1_gate(fc1_weights,
                         gate_weights,
                         act_scales,
                         layernorm_weights=None,
                         layernorm_bias=None,
                         alpha=0.5,
                         weight_scales=None):
    """Like smooth_gemm, but derives one set of scales for the fused fc1/gate pair."""
    gemm_weights = []
    if not isinstance(fc1_weights, list):
        fc1_weights = [fc1_weights]
    if not isinstance(gate_weights, list):
        gate_weights = [gate_weights]

    for i in range(len(fc1_weights)):
        gemm_weight = torch.cat([fc1_weights[i], gate_weights[i]], dim=0)
        gemm_weights.append(gemm_weight)

    orig_dtype = gemm_weights[0].dtype

    for gemm in gemm_weights:
        # gemm_weights are expected to be transposed, i.e. [out, in].
        assert gemm.shape[1] == act_scales.numel()

    if weight_scales is None:
        weight_scales = torch.cat(
            [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights],
            dim=0)
        weight_scales = weight_scales.max(dim=0)[0]
    weight_scales = weight_scales.to(float).clamp(min=1e-5)

    scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) /
              weight_scales.pow(1 - alpha)).clamp(min=1e-5)

    apply_smoothing(scales, fc1_weights + gate_weights, layernorm_weights,
                    layernorm_bias, orig_dtype)

    return scales


@torch.no_grad()
def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5):
    """Smooth a LayerNorm followed by one or more nn.Linear layers in place."""
    if not isinstance(fcs, list):
        fcs = [fcs]
    for fc in fcs:
        assert isinstance(fc, nn.Linear)
        assert ln.weight.numel() == fc.in_features == act_scales.numel()

    device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
    act_scales = act_scales.to(device=device, dtype=dtype)
    weight_scales = torch.cat(
        [fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)
    weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)

    scales = (act_scales.pow(alpha) /
              weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype)

    if ln is not None:
        ln.weight.div_(scales)
        ln.bias.div_(scales)
    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))
    return scales
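
# ---------------------------------------------------------------------------
# Illustrative note (not part of the original module): smoothing computes
#     s_j = max|x_j|^alpha / max|w_j|^(1 - alpha)
# per input channel j, folds 1/s into the norm and s into the GEMM columns,
# so LayerNorm(x) @ W.T is mathematically unchanged while activation
# outliers shrink. A minimal sanity check of that invariant on random data:
@torch.no_grad()
def _check_smoothing_invariant():
    torch.manual_seed(0)
    ln = nn.LayerNorm(8)
    fc = nn.Linear(8, 4)
    x = torch.randn(2, 8)
    ref = fc(ln(x))
    act_scales = ln(x).abs().amax(dim=0)
    smooth_ln_fcs(ln, fc, act_scales, alpha=0.5)
    # The rescaled pair must reproduce the original output up to fp error.
    assert torch.allclose(fc(ln(x)), ref, atol=1e-5)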

@torch.no_grad()
def capture_activation_range(model,
                             tokenizer,
                             dataset,
                             num_samples=512,
                             seq_len=512):
    """Run calibration samples through the model and record, per Linear/Conv1D
    module, the per-channel max |input| ("x"), max |output| ("y") and
    per-output-row max |weight| ("w")."""
    model.eval()
    device = next(model.parameters()).device
    act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None})

    tokenizer.pad_token = tokenizer.eos_token

    def stat_tensor(name, tensor, act_scales, key):
        hidden_dim = tensor.shape[-1]
        tensor = tensor.view(-1, hidden_dim).abs().detach()
        coming_max = torch.max(tensor, dim=0)[0].float()

        if act_scales[name][key] is None:
            act_scales[name][key] = coming_max
        else:
            act_scales[name][key] = torch.max(act_scales[name][key],
                                              coming_max)

    def stat_input_hook(m, x, y, name):
        if isinstance(x, tuple):
            x = x[0]
        stat_tensor(name, x, act_scales, "x")
        stat_tensor(name, y, act_scales, "y")
        if act_scales[name]["w"] is None:
            act_scales[name]["w"] = m.weight.abs().clip(1e-8,
                                                        None).max(dim=1)[0]

    hooks = []
    for name, m in model.named_modules():
        if isinstance(m, (nn.Linear, Conv1D)):
            hooks.append(
                m.register_forward_hook(
                    functools.partial(stat_input_hook, name=name)))

    for i in tqdm(range(num_samples), desc="calibrating model"):
        datapoint = dataset['train'][i:i + 1]
        line = copy.copy(datapoint['article'])
        # Light prompt normalization for the summarization calibration set.
        line[0] = line[0] + ' TL;DR: '
        line[0] = line[0].strip()
        line[0] = line[0].replace(" n't", "n't")
        input_ids = tokenizer(line,
                              return_tensors="pt",
                              max_length=seq_len,
                              padding=True,
                              truncation=True).input_ids.to(device)
        model(input_ids)

    for h in hooks:
        h.remove()

    return act_scales
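
# ---------------------------------------------------------------------------
# Example end-to-end usage (an illustrative sketch, not part of the original
# module). It assumes a Hugging Face LLaMA checkpoint and a calibration set
# with an 'article' column, such as ccdv/cnn_dailymail; the decoder-layer
# attribute names follow the standard `transformers` LlamaDecoderLayer layout
# and are assumptions here. Only the norm-adjacent projections are shown.
def _example_smooth_llama(model_dir, num_samples=64):
    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained(model_dir,
                                                 torch_dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    dataset = load_dataset("ccdv/cnn_dailymail", "3.0.0")

    # 1) Record per-channel activation/weight ranges on calibration data.
    act_scales = capture_activation_range(model, tokenizer, dataset,
                                          num_samples=num_samples)

    # 2) Fold activation outliers into the norms and fused projections.
    for i, layer in enumerate(model.model.layers):
        prefix = f"model.layers.{i}"
        # Q/K/V projections share the attention input, so they share scales.
        smooth_gemm(
            [layer.self_attn.q_proj.weight, layer.self_attn.k_proj.weight,
             layer.self_attn.v_proj.weight],
            act_scales[f"{prefix}.self_attn.q_proj"]["x"],
            layer.input_layernorm.weight, None)
        # gate/up projections are fused downstream; scale them jointly.
        smooth_gemm_fc1_gate(
            layer.mlp.gate_proj.weight, layer.mlp.up_proj.weight,
            act_scales[f"{prefix}.mlp.gate_proj"]["x"],
            layer.post_attention_layernorm.weight, None)
    return model, act_scales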