maga_transformer/utils/smooth_quant_convert/llama/convert.py (191 lines of code) (raw):

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Utilities for exporting a model to our custom format. """ import numpy as np import torch def save_val(val, dir, key, tp_num=None): dir[f"model.{key}"] = torch.from_numpy(np.asarray(val)) def save_split(split_vals, dir, key, i, factor): for j, val in enumerate(split_vals): save_val(val, dir, key, i * factor + j) def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False): """ This function has two purposes: - compute quantized weights, scaled either per-tensor or per-column - compute scaling factors Depending on the GEMM API (CUTLASS/CUBLAS) the required scaling factors differ. CUTLASS uses two sets of scaling factors. One for the activation X, one for the weight W. CUBLAS only has one (we can't do per-row scaling). So we must provide pre-multiplied scaling factor. Here is the list of what we need (T means per-tensor, C per-column): - scale_x_orig_quant puts fp activation into the quantized range (i.e. [-128, 127], for int8). Used before the GEMM. (T) # - scale_y_quant_orig puts quantized activation into the fp range. Used if the GEMM outputs int8. (T) - scale_w_quant_orig puts weights from quant range to fp range (used with CUTLASS) (T, C) # - scale_y_accum_quant puts the GEMM result (XW) from accumulation range (int32) # to quant range (int8) (used for CUBLAS) (T, C) Note that we don't do anything special about row-parallel GEMM. Theoretically, we could have per-GPU scaling factors too, but then the model would change depending on the number of GPUs used. For QKV projection, the behavior is special. Even if we have a single matrix to perform QKV projection, we consider it as three different matrices: Q, K, and V. So per-tensor actually means one scaling factor for each Q, K and V. For our GEMM implementation to respect this behavior, we use per-column mode and replicate values along columns. """ # compute weight scaling factors for fp->int8 and int8->fp if is_qkv and not multi_query_mode: scale_w_orig_quant_t = 127. / act_range["w"].reshape(3, -1).max( dim=-1, keepdims=True)[0].cpu().numpy() scale_w_orig_quant_c = 127. / act_range["w"].reshape(3, -1).cpu().numpy() elif is_qkv and multi_query_mode: hidden_dim = weights.shape[0] local_dim = act_range["w"].shape[0] kv_dim = (local_dim - hidden_dim) // 2 scale_w_q = act_range["w"][0:hidden_dim] scale_w_k = act_range["w"][hidden_dim:hidden_dim + kv_dim] scale_w_v = act_range["w"][-kv_dim:] scale_w_qkv_t = torch.concat([ scale_w_q.max(dim=0, keepdim=True)[0], scale_w_k.max(dim=0, keepdim=True)[0], scale_w_v.max(dim=0, keepdim=True)[0] ]) scale_w_orig_quant_t = 127. / scale_w_qkv_t.cpu().numpy() scale_w_orig_quant_c = 127. / act_range["w"].cpu().numpy() else: scale_w_orig_quant_t = 127. / act_range["w"].max().cpu().numpy() scale_w_orig_quant_c = 127. / act_range["w"].cpu().numpy() scale_w_quant_orig_t = 1.0 / scale_w_orig_quant_t scale_w_quant_orig_c = 1.0 / scale_w_orig_quant_c # compute the rest of needed scaling factors scale_x_orig_quant_t = np.array(127. / act_range["x"].max().item()) scale_y_orig_quant_t = np.array(127. / act_range["y"].max().item()) scale_y_quant_orig_t = np.array(act_range["y"].max().item() / 127.) scale_y_accum_quant_t = scale_y_orig_quant_t / (scale_x_orig_quant_t * scale_w_orig_quant_t) scale_y_accum_quant_c = scale_y_orig_quant_t / (scale_x_orig_quant_t * scale_w_orig_quant_c) if is_qkv and not multi_query_mode: scale_y_accum_quant_t = np.broadcast_to(scale_y_accum_quant_t, scale_w_orig_quant_c.shape) scale_w_quant_orig_t = np.broadcast_to(scale_w_quant_orig_t, scale_w_orig_quant_c.shape) if is_qkv and multi_query_mode: scale_q_y_accum_t = np.broadcast_to(scale_y_accum_quant_t[0], scale_w_q.shape) scale_k_y_accum_t = np.broadcast_to(scale_y_accum_quant_t[1], scale_w_k.shape) scale_v_y_accum_t = np.broadcast_to(scale_y_accum_quant_t[2], scale_w_v.shape) scale_y_accum_quant_t = np.concatenate( [scale_q_y_accum_t, scale_k_y_accum_t, scale_v_y_accum_t]) scale_w_quant_orig_t = np.concatenate([ np.broadcast_to(scale_w_quant_orig_t[0], scale_w_q.shape), np.broadcast_to(scale_w_quant_orig_t[1], scale_w_k.shape), np.broadcast_to(scale_w_quant_orig_t[2], scale_w_v.shape) ]) to_i8 = lambda x: x.round().clip(-127, 127).astype(np.int8) if is_qkv and multi_query_mode: scale_w_quant_orig_t_expand = np.ones([weights.shape[-1]]) scale_w_quant_orig_t_expand[:hidden_dim] = scale_w_quant_orig_t[0] scale_w_quant_orig_t_expand[hidden_dim:hidden_dim + kv_dim] = scale_w_quant_orig_t[1] scale_w_quant_orig_t_expand[-kv_dim:] = scale_w_quant_orig_t[2] weight_int8 = to_i8(weights * scale_w_quant_orig_t_expand) else: weight_int8 = to_i8(weights * scale_w_orig_quant_t) return { "weight.int8": weight_int8, "weight.int8.col": to_i8(weights * scale_w_orig_quant_c), "scale_x_orig_quant": scale_x_orig_quant_t.astype(np.float32), "scale_w_quant_orig": scale_w_quant_orig_t.astype(np.float32), "scale_w_quant_orig.col": scale_w_quant_orig_c.astype(np.float32), "scale_y_accum_quant": scale_y_accum_quant_t.astype(np.float32), "scale_y_accum_quant.col": scale_y_accum_quant_c.astype(np.float32), "scale_y_quant_orig": scale_y_quant_orig_t.astype(np.float32), } def save_multi_query_mode_qkv_int8(val, dir, base_key, saved_key, factor, rank, local_dim, head_size): q, k, v = np.split(val, [local_dim, local_dim + head_size], axis=-1) q_split = np.split(q, factor, axis=-1) k_split = np.split(k, factor, axis=-1) v_split = np.split(v, factor, axis=-1) split_vals = [ np.concatenate((q_split[ii], k_split[ii], v_split[ii]), axis=-1) for ii in range(factor) ] save_split(split_vals, dir, f"{base_key}.{saved_key}", rank, factor) def write_int8(vals, dir, base_key, split_dim, i, factor, is_qkv=False, multi_query_mode=False): saved_keys_once = [ "scale_x_orig_quant", "scale_w_quant_orig" ] # saved_keys_once = [ # "scale_x_orig_quant", "scale_w_quant_orig", "scale_y_accum_quant", # "scale_y_quant_orig" # ] # if is_qkv and multi_query_mode: # assert split_dim == -1 # local_dim = vals["weight.int8"].shape[0] # head_size = (vals["weight.int8"].shape[1] - local_dim) // 2 # # save_multi_query_mode_qkv_int8(vals["weight.int8"], dir, base_key, # # "weight.int8", factor, i, local_dim, # # head_size) # save_multi_query_mode_qkv_int8(vals["weight.int8.col"], dir, base_key, # "weight.int8.col", factor, i, local_dim, # head_size) # save_multi_query_mode_qkv_int8(vals["scale_w_quant_orig.col"], dir, # base_key, "scale_w_quant_orig.col", # factor, i, local_dim, head_size) # # save_multi_query_mode_qkv_int8(vals["scale_y_accum_quant.col"], dir, # # base_key, "scale_y_accum_quant.col", # # factor, i, local_dim, head_size) # # save_multi_query_mode_qkv_int8(vals["scale_w_quant_orig"], dir, # # base_key, "scale_w_quant_orig", factor, # # i, local_dim, head_size) # # save_multi_query_mode_qkv_int8(vals["scale_y_accum_quant"], dir, # # base_key, "scale_y_accum_quant", factor, # # i, local_dim, head_size) # saved_keys_once = ["scale_x_orig_quant", "scale_y_quant_orig"] # else: # save_split(np.split(vals["weight.int8"], factor, axis=split_dim), dir, # f"{base_key}.weight.int8", i, factor) save_split(np.split(vals["weight.int8.col"], factor, axis=split_dim), dir, f"{base_key}.weight.int8.col", i, factor) if split_dim == -1: save_split( np.split(vals["scale_w_quant_orig.col"], factor, axis=split_dim), dir, f"{base_key}.scale_w_quant_orig.col", i, factor) # save_split( # np.split(vals["scale_y_accum_quant.col"], # factor, # axis=split_dim), dir, # f"{base_key}.scale_y_accum_quant.col", i, factor) # if is_qkv: # save_split( # np.split(vals["scale_y_accum_quant"], # factor, # axis=split_dim), dir, # f"{base_key}.scale_y_accum_quant", i, factor) # save_split( # np.split(vals["scale_w_quant_orig"], factor, # axis=split_dim), dir, # f"{base_key}.scale_w_quant_orig", i, factor) # saved_keys_once = ["scale_x_orig_quant", "scale_y_quant_orig"] saved_keys_once = ["scale_x_orig_quant"] else: saved_keys_once += [ # "scale_w_quant_orig.col", "scale_y_accum_quant.col" "scale_w_quant_orig.col" ] if i == 0: for save_key in saved_keys_once: save_val(vals[save_key], dir, f"{base_key}.{save_key}") def str_to_np_dtype(type_str): convert_dict = { "fp32": np.float32, "fp16": np.float16, } dtype = convert_dict.get(type_str) if dtype is None: raise ValueError(f"{type_str} is an invalid storage type") return dtype def split_and_save_weight(i, saved_dir, factor, key, val, act_range, config): saved_dir = {} # The split_factor indicates the number of ranks to implement # distributed GEMMs. For Tensor Parallelism, each rank/GPU works # on split_hidden_dim // split_factor channels. int8_outputs = config.get("int8_outputs", None) multi_query_mode = config.get("multi_query_mode", False) local_dim = config.get("local_dim", None) save_int8 = int8_outputs == "all" or int8_outputs == "kv_cache_only" if "input_layernorm.weight" in key or "input_layernorm.bias" in key or \ "attention.dense.bias" in key or "post_layernorm.weight" in key or \ "post_attention_layernorm.bias" in key or "mlp.dense_4h_to_h.bias" in key or \ "final_layernorm.weight" in key or "final_layernorm.bias" in key: # shared weights, only need to convert the weights of rank 0 if i == 0: save_val(val, saved_dir, key) elif "attention.dense.weight" in key or "mlp.proj.weight" in key: split_dim = 0 split_vals = np.split(val, factor, axis=split_dim) # save_split(split_vals, saved_dir, key, i, factor) if act_range is not None and int8_outputs == "all": base_key = key.replace(".weight", "") vals_i8 = generate_int8(val, act_range) write_int8(vals_i8, saved_dir, base_key, split_dim, i, factor) elif "mlp.fc.weight" in key or "mlp.gate.weight" in key: split_dim = -1 split_vals = np.split(val, factor, axis=split_dim) # save_split(split_vals, saved_dir, key, i, factor) if act_range is not None and int8_outputs == "all": base_key = key.replace(".weight", "") vals_i8 = generate_int8(val, act_range) write_int8(vals_i8, saved_dir, base_key, split_dim, i, factor) elif "attention.query_key_value.weight" in key: hidden_dim = val.shape[0] if local_dim is None: local_dim = val.shape[-1] // 3 if multi_query_mode: head_size = (val.shape[-1] - local_dim) // 2 val = val.reshape(hidden_dim, local_dim + 2 * head_size) w_q, w_k, w_v = np.split(val, [local_dim, local_dim + head_size], axis=-1) w_q_split = np.split(w_q, factor, axis=-1) w_k_split = np.split(w_k, factor, axis=-1) w_v_split = np.split(w_v, factor, axis=-1) split_vals = [ np.concatenate((w_q_split[ii], w_k_split[ii], w_v_split[ii]), axis=-1) for ii in range(factor) ] split_dim = -1 else: val = val.reshape(hidden_dim, 3, local_dim) split_dim = -1 split_vals = np.split(val, factor, axis=split_dim) # save_split(split_vals, saved_dir, key, i, factor) if save_int8: base_key = key.replace(".weight", "") vals_i8 = generate_int8(val, act_range, is_qkv=True, multi_query_mode=multi_query_mode) write_int8(vals_i8, saved_dir, base_key, split_dim, i, factor, is_qkv=True, multi_query_mode=multi_query_mode) elif "attention.dense.smoother" in key or "mlp.proj.smoother" in key: split_vals = np.split(val, factor, axis=0) save_split(split_vals, saved_dir, key, i, factor) else: print(f"[WARNING] {key} not handled by converter") return saved_dir