maga_transformer/device/device_impl.py (305 lines of code) (raw):

import re
from maga_transformer.device.device_base import DeviceBase, DeviceType, MemInfo
from maga_transformer.ops import DeviceType, DeviceExporter
from maga_transformer.utils.model_weight import W
import torch
import psutil
import os
import logging


class CpuImpl(DeviceBase):
    def __init__(self, exported_device: DeviceExporter):
        super().__init__(exported_device)

    def _get_mem_info(self) -> MemInfo:
        vmem = psutil.virtual_memory()
        return MemInfo(vmem.used, vmem.free)


class ArmCpuImpl(CpuImpl):
    def __init__(self, exported_device: DeviceExporter):
        super().__init__(exported_device)
        self.gemm_rewrite_list = [
            W.attn_qkv_w,
            W.attn_o_w,
            W.ffn_w1,
            W.ffn_w2,
            W.ffn_w3,
        ]

    def maybe_rewrite_weight_by_key(self, key: str, weight: torch.Tensor) -> torch.Tensor:
        return self.exported_device.preprocess_gemm_weight_by_key(key, weight)

    def unpack_int32_into_int16(self, w_packed: torch.Tensor, int8: bool):
        if int8:
            return w_packed.contiguous().view(torch.uint8).to(torch.int16)
        # unpack inputs packed in int32/float32 into uint4 and store them in int8 format
        w_packed_int4x2 = w_packed.contiguous().view(torch.uint8)
        w_unpacked = torch.zeros(w_packed_int4x2.shape[0], w_packed_int4x2.shape[1] * 2, dtype=torch.int8)
        w_unpacked[:, ::2] = w_packed_int4x2 % 16
        w_unpacked[:, 1::2] = w_packed_int4x2 // 16
        return w_unpacked.to(torch.int16).contiguous()

    def preprocess_groupwise_weight_params(self, qweight_int32, qzeros_int32, scales_fp16, device: str, gptq: bool, awq: bool, weight_bits: int):
        GPTQ_FLAG = 1 if gptq else 0
        qweight = qweight_int32.reshape(qweight_int32.shape[0], -1).cpu()
        qzeros = qzeros_int32.reshape(qzeros_int32.shape[0], -1).cpu()
        scales_fp16 = scales_fp16.reshape(scales_fp16.shape[0], -1).cpu()
        packer = self.exported_device.pack_int8_tensor_to_packed_int4
        preprocess_weight_scale = self.exported_device.preprocess_weight_scale
        is_int8 = weight_bits == 8
        if is_int8:
            zero_shift = 128
            quant_type = torch.int8
        else:
            zero_shift = 8
            quant_type = torch.quint4x2

        if awq:
            qweight = self.unpack_int32_into_int16(qweight, is_int8).contiguous() - zero_shift
            qweight = self.reverse_awq_order(qweight)
        elif gptq:
            qweight = self.unpack_int32_into_int16(qweight.T, is_int8).T.contiguous() - zero_shift

        qweight = qweight.to(torch.int8)
        if not is_int8:
            qweight = packer(qweight)
        qweight_interleaved = preprocess_weight_scale(qweight, scales_fp16)

        # zero = 0 if qzeros_int32 = -2004318072 torch.int32 for awq
        # zero = 0 if qzeros_int32 = 2004318071 torch.int32 for gptq
        qzeros = self.unpack_int32_into_int16(qzeros, is_int8)
        if awq:
            qzeros = self.reverse_awq_order(qzeros)

        # zeros = zeros * scales
        UINT_TO_INT_FLAG = 1
        zeros_x_scales_fp16 = (-qzeros + zero_shift * UINT_TO_INT_FLAG - GPTQ_FLAG) * scales_fp16
        zeros_x_scales_fp16 = zeros_x_scales_fp16.half()

        # return processed interleaved weight, original scales and zeros * scales
        return qweight_interleaved.contiguous().to(device), zeros_x_scales_fp16.contiguous().to(device), scales_fp16.contiguous().to(device)


class GpuImpl(DeviceBase):
    def __init__(self, exported_device: DeviceExporter):
        super().__init__(exported_device)

    def get_device_id(self) -> int:
        return torch.cuda.current_device()

    def unpack_int32_into_int16(self, w_packed: torch.Tensor, int8: bool):
        if int8:
            return w_packed.contiguous().view(torch.uint8).to(torch.int16)
        # unpack inputs packed in int32/float32 into uint4 and store them in int8 format
        w_packed_int4x2 = w_packed.contiguous().view(torch.uint8)
        w_unpacked = torch.zeros(w_packed_int4x2.shape[0], w_packed_int4x2.shape[1] * 2, dtype=torch.int8)
        w_unpacked[:, ::2] = w_packed_int4x2 % 16
        w_unpacked[:, 1::2] = w_packed_int4x2 // 16
        return w_unpacked.to(torch.int16).contiguous()
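    # Illustrative note (editorial, not part of the original file): the low/high nibble split
    # above expands each packed uint8 byte into two int4 values, e.g. for a 1x1 int32 on a
    # little-endian machine:
    #   packed = torch.tensor([[0x4F]], dtype=torch.int32).view(torch.uint8)  # [[79, 0, 0, 0]]
    #   packed % 16   -> [[15, 0, 0, 0]]   (low nibbles, even output columns)
    #   packed // 16  -> [[ 4, 0, 0, 0]]   (high nibbles, odd output columns)
    # so a packed column count K becomes 2 * K unpacked int16 values.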
    def reverse_awq_order(self, ori_tensor: torch.Tensor):
        # AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
        assert ori_tensor.shape[-1] % 8 == 0
        reorder_tensor = ori_tensor.reshape(-1, 2, 4).transpose(2, 1).reshape(ori_tensor.shape)
        return reorder_tensor

    @property
    def specify_gpu_arch(self):
        return os.environ.get('SPECIFY_GPU_ARCH', "")

    def apply_int8(self, tensor: torch.Tensor, device: str):
        shape = tensor.shape
        int8_weight, int8_scale = self.exported_device.symmetric_quantize_last_axis_of_batched_matrix(  # type: ignore
            tensor.reshape([shape[0], -1]).cpu(), torch.int8, self.specify_gpu_arch)
        int8_weight = int8_weight.reshape(shape)
        return int8_weight.to(device), int8_scale.to(device)

    def moe_apply_int8(self, tensor: torch.Tensor, device: str):
        assert tensor.dim() == 3
        tensor_list = torch.chunk(tensor, tensor.shape[0], dim=0)
        int8_weights = []
        int8_scales = []
        for t in tensor_list:
            t = torch.squeeze(t).transpose(1, 0).contiguous()
            shape = t.shape
            weight, scale = self.exported_device.symmetric_quantize_last_axis_of_batched_matrix(  # type: ignore
                t.reshape([shape[0], -1]).cpu(), torch.int8, self.specify_gpu_arch)
            int8_weights.append(weight)
            int8_scales.append(scale)
        int8_weight = torch.stack(int8_weights, dim=0)
        int8_scale = torch.stack(int8_scales, dim=0)
        return int8_weight.to(device), int8_scale.to(device)

    def preprocess_groupwise_weight_params(self, qweight_int32, qzeros_int32, scales_fp16, device: str, gptq: bool, awq: bool, weight_bits: int):
        GPTQ_FLAG = 1 if gptq else 0
        qweight = qweight_int32.reshape(qweight_int32.shape[0], -1).cpu()
        qzeros = qzeros_int32.reshape(qzeros_int32.shape[0], -1).cpu()
        scales_fp16 = scales_fp16.reshape(scales_fp16.shape[0], -1).cpu()
        packer = self.exported_device.pack_int8_tensor_to_packed_int4
        preprocessor = self.exported_device.preprocess_weights_for_mixed_gemm
        is_int8 = weight_bits == 8
        if is_int8:
            zero_shift = 128
            quant_type = torch.int8
        else:
            zero_shift = 8
            quant_type = torch.quint4x2

        if awq:
            qweight = self.unpack_int32_into_int16(qweight, is_int8).contiguous() - zero_shift
            qweight = self.reverse_awq_order(qweight)
        elif gptq:
            qweight = self.unpack_int32_into_int16(qweight.T, is_int8).T.contiguous() - zero_shift

        qweight = qweight.to(torch.int8)
        if not is_int8:
            qweight = packer(qweight)
        qweight_interleaved = preprocessor(qweight, quant_type, self.specify_gpu_arch)

        # zero = 0 if qzeros_int32 = -2004318072 torch.int32 for awq
        # zero = 0 if qzeros_int32 = 2004318071 torch.int32 for gptq
        qzeros = self.unpack_int32_into_int16(qzeros, is_int8)
        if awq:
            qzeros = self.reverse_awq_order(qzeros)

        # zeros = zeros * scales
        UINT_TO_INT_FLAG = 1
        zeros_x_scales_fp16 = (-qzeros + zero_shift * UINT_TO_INT_FLAG - GPTQ_FLAG) * scales_fp16
        zeros_x_scales_fp16 = zeros_x_scales_fp16.half()

        # return processed interleaved weight, original scales and zeros * scales
        return qweight_interleaved.contiguous().to(device), zeros_x_scales_fp16.contiguous().to(device), scales_fp16.contiguous().to(device)
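    # Illustrative note (editorial, not part of the original file): the zero-point comments
    # above follow from zeros_x_scales = (-q_zero + zero_shift - GPTQ_FLAG) * scale.
    # For int4 GPTQ, qzeros_int32 = 2004318071 = 0x77777777 unpacks to nibble 7 everywhere,
    # so (-7 + 8 - 1) * scale = 0; for int4 AWQ, -2004318072 = 0x88888888 unpacks to nibble 8,
    # so (-8 + 8 - 0) * scale = 0.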
    def preprocess_moe_groupwise_weight_params(self, qweight_int32, qzeros_int32, scales_fp16, device: str, gptq: bool, awq: bool, weight_bits: int):
        assert qweight_int32.dim() == 3
        qweight_list = torch.chunk(qweight_int32, qweight_int32.shape[0], dim=0)
        qzeros_list = torch.chunk(qzeros_int32, qzeros_int32.shape[0], dim=0)
        scales_list = torch.chunk(scales_fp16, scales_fp16.shape[0], dim=0)
        processed_weights = []
        processed_zeros = []
        processed_scales = []
        for w, z, s in zip(qweight_list, qzeros_list, scales_list):
            w = torch.squeeze(w).transpose(1, 0).contiguous()
            z = torch.squeeze(z).transpose(1, 0).contiguous()
            s = torch.squeeze(s).transpose(1, 0).contiguous()
            p_w, p_z, p_s = self.preprocess_groupwise_weight_params(w, z, s, device, gptq, awq, weight_bits)
            processed_weights.append(p_w)
            processed_zeros.append(p_z)
            processed_scales.append(p_s)
        processed_weights = torch.stack(processed_weights, dim=0)
        processed_zeros = torch.stack(processed_zeros, dim=0)
        processed_scales = torch.stack(processed_scales, dim=0)
        return processed_weights, processed_zeros, processed_scales

    def shuffle_moe_weight(self, x: torch.Tensor, datatype: torch.dtype, name: str) -> torch.Tensor:
        return x


class CudaImpl(GpuImpl):
    def __init__(self, exported_device: DeviceExporter):
        super().__init__(exported_device)
        try:
            import pynvml
            pynvml.nvmlInit()
        except Exception as e:
            logging.warn(f"no nvml found: {e}")

    def _get_mem_info(self) -> MemInfo:
        import pynvml
        handle = pynvml.nvmlDeviceGetHandleByIndex(torch.cuda._parse_visible_devices()[0])
        meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return MemInfo(meminfo.used, meminfo.free)

    @property
    def arch(self) -> int:
        try:
            device = self.get_device_id()
            major, minor = torch.cuda.get_device_capability(device)
            arch = major * 10 + minor
            return arch
        except Exception as e:
            logging.warn(f"Cannot get CUDA device capability: {e}")
            return super().arch  # fall back to the parent implementation

    @property
    def support_dio_load(self) -> bool:
        return True


class PpuImpl(CudaImpl):
    @property
    def support_dio_load(self) -> bool:
        return False


class RocmImpl(GpuImpl):
    def __init__(self, exported_device: DeviceExporter):
        super().__init__(exported_device)
        self.rocml = None  # kept for arch(); stays None when rocm smi is unavailable
        try:
            from pyrsmi import rocml
            rocml.smi_initialize()
            self.rocml = rocml
        except Exception as e:
            logging.warn(f"no rocm smi found: {e}")

    def _get_mem_info(self) -> MemInfo:
        from pyrsmi import rocml
        id = self.get_device_id()
        used = rocml.smi_get_device_memory_used(id)
        total = rocml.smi_get_device_memory_total(id)
        # MemInfo takes (used, free), matching the CPU and CUDA implementations above
        return MemInfo(used, total - used)

    @property
    def arch(self) -> str:
        if self.rocml:
            try:
                id = self.get_device_id()
                device_name = self.rocml.smi_get_device_name(id)
                # extract the architecture from the device name (assumes it contains a gfx version)
                gfx_match = re.search(r'gfx(\d+)', device_name)
                if gfx_match:
                    return gfx_match.group(1)
            except Exception as e:
                logging.warn(f"Cannot get ROCm device gfx version: {e}")
        # otherwise fall back to the environment variable or the default
        return os.environ.get('SPECIFY_GPU_ARCH', "900")
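    # Illustrative note (editorial, not part of the original file): the override below mirrors
    # GpuImpl.preprocess_groupwise_weight_params but additionally re-lays scales and zeros out
    # column-major for the CK kernels. `t.transpose(0, 1).contiguous().transpose(1, 0)` keeps the
    # logical shape of `t` while changing the storage layout, e.g. a 2x3 tensor's strides go
    # from (3, 1) to (1, 2).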
    def preprocess_groupwise_weight_params(self, qweight_int32, qzeros_int32, scales_fp16, device: str, gptq: bool, awq: bool, weight_bits: int):
        GPTQ_FLAG = 1 if gptq else 0
        qweight = qweight_int32.reshape(qweight_int32.shape[0], -1).cpu()
        qzeros = qzeros_int32.reshape(qzeros_int32.shape[0], -1).cpu()
        scales_fp16 = scales_fp16.reshape(scales_fp16.shape[0], -1).cpu()
        packer = self.exported_device.pack_int8_tensor_to_packed_int4
        preprocessor = self.exported_device.preprocess_weights_for_mixed_gemm
        is_int8 = weight_bits == 8
        if is_int8:
            zero_shift = 128
            quant_type = torch.int8
        else:
            zero_shift = 8
            quant_type = torch.quint4x2

        if awq:
            qweight = self.unpack_int32_into_int16(qweight, is_int8).contiguous() - zero_shift
            qweight = self.reverse_awq_order(qweight)
        elif gptq:
            qweight = self.unpack_int32_into_int16(qweight.T, is_int8).T.contiguous() - zero_shift

        qweight = qweight.to(torch.int8)
        if not is_int8:
            qweight = packer(qweight)
        qweight_interleaved = preprocessor(qweight, quant_type, self.specify_gpu_arch)

        # zero = 0 if qzeros_int32 = -2004318072 torch.int32 for awq
        # zero = 0 if qzeros_int32 = 2004318071 torch.int32 for gptq
        qzeros = self.unpack_int32_into_int16(qzeros, is_int8)
        if awq:
            qzeros = self.reverse_awq_order(qzeros)

        # zeros = zeros * scales
        UINT_TO_INT_FLAG = 1
        zeros_x_scales_fp16 = (-qzeros + zero_shift * UINT_TO_INT_FLAG - GPTQ_FLAG) * scales_fp16
        zeros_x_scales_fp16 = zeros_x_scales_fp16.half()

        ###########################################################
        # scales row major -> scales column major layout to match CK kernel layout
        # TODO: need to add device information for selection
        scales_fp16_t = scales_fp16.transpose(0, 1).contiguous()
        scales_fp16 = scales_fp16_t.transpose(1, 0).cpu()
        # zeros_x_scales row major -> zeros_x_scales column major layout to match CK kernel layout
        zeros_x_scales_fp16_t = zeros_x_scales_fp16.transpose(0, 1).contiguous()
        zeros_x_scales_fp16 = zeros_x_scales_fp16_t.transpose(1, 0).cpu()
        ###########################################################

        # return processed interleaved weight, zeros * scales and original scales;
        # kernel, scales and zeros all need the column major layout
        return qweight_interleaved.to(device), zeros_x_scales_fp16.to(device), scales_fp16.to(device)

    def shuffle_moe_weight(self, x: torch.Tensor, datatype: torch.dtype, name: str) -> torch.Tensor:
        is_gate = name == W.moe_w1
        align = [0, 512, 0] if is_gate else [0, 0, 512]
        if len(align) != len(x.shape):
            logging.error(f'Unexpected moe weight shape {list(x.shape)} for {name}')
            return x
        # swap from [up, gate] to [gate, up]
        x_ = torch.cat([x[:, x.shape[1] // 2:, :], x[:, :x.shape[1] // 2, :]], dim=1) if is_gate else x
        # gate and up are fused, so temporarily treat the halves separately for padding
        shape_tmp = list(x_.shape)
        if is_gate:
            shape_tmp[1] = shape_tmp[1] // 2
        # align and pad
        padding = [0 for i in range(len(align) * 2)]
        for i in range(len(align)):
            if (align[i] > 0) and (shape_tmp[i] % align[i] > 0):
                padding[-(i * 2 + 1)] = align[i] - (shape_tmp[i] % align[i])
        if sum(padding):
            if is_gate:
                x_ = torch.cat(
                    [torch.nn.functional.pad(x_[:, :x_.shape[1] // 2, :], padding, mode='constant', value=0),
                     torch.nn.functional.pad(x_[:, x_.shape[1] // 2:, :], padding, mode='constant', value=0)],
                    dim=1)
            else:
                x_ = torch.nn.functional.pad(x_, tuple(padding), mode='constant', value=0)
        # logging.info(f'Moe padding shape {[ele for ele in x.shape]} with {padding} to {[ele for ele in x_.shape]}')
        b_: int = x_.shape[0]
        n_: int = x_.shape[1]
        k_: int = x_.shape[2]
        if (datatype == torch.float16) or (datatype == torch.bfloat16):
            x_ = x_.view(b_, n_ // 16, 16, k_ // 32, 4, 8)
        elif (datatype == torch.float8) or (datatype == torch.int8):
            x_ = x_.view(b_, n_ // 16, 16, k_ // 64, 4, 16)
        else:
            logging.error(f'Data type for moe weight is not supported: {datatype}')
            return x
        x_ = x_.permute(0, 1, 3, 4, 2, 5).contiguous()
        x_ = x_.view(b_, n_, k_)
        x_ = x_.contiguous()
        return x_
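
# Illustrative sanity check (editorial sketch, not part of the original module): it only
# exercises the pure-tensor nibble unpacking used by unpack_int32_into_int16 above, so it
# needs no DeviceExporter and no physical device.
if __name__ == "__main__":
    packed = torch.tensor([[0x21, 0x43]], dtype=torch.uint8)  # bytes holding nibbles (1, 2) and (3, 4)
    unpacked = torch.zeros(packed.shape[0], packed.shape[1] * 2, dtype=torch.int8)
    unpacked[:, ::2] = packed % 16    # low nibbles -> even columns
    unpacked[:, 1::2] = packed // 16  # high nibbles -> odd columns
    assert unpacked.tolist() == [[1, 2, 3, 4]]
    print("nibble unpack order:", unpacked.tolist())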