tinynn/graph/quantization/modules.py
import torch
import torch.nn as nn
import torch.nn.quantized as nnq
import torch.quantization as torch_q


class QPReLU(nn.Module):
    def __init__(self, prelu: nn.PReLU) -> None:
        super().__init__()

        # Copy the weight from the existing PReLU object
        self.weight = torch.nn.Parameter(prelu.weight.data.detach().clone())

        # Other modules needed for QAT preparation
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.f_mul_neg_one1 = nnq.FloatFunctional()
        self.f_mul_neg_one2 = nnq.FloatFunctional()
        self.f_mul_alpha = nnq.FloatFunctional()
        self.f_add = nnq.FloatFunctional()
        self.quant = torch_q.QuantStub()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # prelu(x) = relu(x) + weight * -relu(-x)
        # where negation is implemented by multiplying with -1.
        # We use the `add` and `mul` variants defined in `FloatFunctional`,
        # and the weight goes through a QuantStub so that the whole
        # computation graph stays quantized.
        x1 = self.relu1(input)

        if self.weight.numel() == 1:
            weight = self.weight.view(())
        else:
            weight_shape = self.weight.shape + (1,) * (input.dim() - 2)
            weight = self.weight.view(*weight_shape)

        weight_q = self.quant(weight)
        x2 = self.f_mul_alpha.mul(
            weight_q,
            self.f_mul_neg_one2.mul_scalar(
                self.relu2(
                    self.f_mul_neg_one1.mul_scalar(input, -1.0),
                ),
                -1.0,
            ),
        )
        x = self.f_add.add(x1, x2)
        return x
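
# Usage sketch (illustrative, not part of the original module): QPReLU is meant to be a
# drop-in replacement for nn.PReLU before QAT preparation; in float mode the QuantStub
# and FloatFunctional calls reduce to identity/ordinary ops, so the outputs should match.
# The channel count and input shape below are assumptions for illustration only.
#
#   float_prelu = nn.PReLU(num_parameters=4)
#   q_prelu = QPReLU(float_prelu)
#   x = torch.randn(2, 4, 8, 8)
#   assert torch.allclose(float_prelu(x), q_prelu(x), atol=1e-6)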


class QSiLU(nn.Module):
    def __init__(self, _: 'nn.SiLU') -> None:
        super().__init__()

        self.act = nn.Sigmoid()
        self.f_mul = nnq.FloatFunctional()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self.f_mul.mul(input, self.act(input))
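
# Usage sketch (illustrative, not part of the original module): QSiLU rewrites
# silu(x) = x * sigmoid(x) with a FloatFunctional `mul`, so in float mode it should
# match nn.SiLU. The input shape is an assumption for illustration only.
#
#   x = torch.randn(2, 16)
#   assert torch.allclose(nn.SiLU()(x), QSiLU(nn.SiLU())(x), atol=1e-6)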


class QGLU(nn.Module):
    def __init__(self, glu: nn.GLU) -> None:
        super().__init__()

        self.dim = glu.dim
        self.f_mul = nnq.FloatFunctional()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        slices = torch.chunk(input, 2, self.dim)
        return self.f_mul.mul(slices[0], self.sigmoid(slices[1]))
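
# Usage sketch (illustrative, not part of the original module): QGLU splits the input
# along `dim` and computes slices[0] * sigmoid(slices[1]), which in float mode should
# match nn.GLU. The split dimension and input shape are assumptions for illustration only.
#
#   glu = nn.GLU(dim=1)
#   x = torch.randn(2, 8, 4)
#   assert torch.allclose(glu(x), QGLU(glu)(x), atol=1e-6)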


class QHardsigmoid(nn.Module):
    def __init__(self, hardsigmoid: nn.Hardsigmoid) -> None:
        super().__init__()

        self.f_mul = nnq.FloatFunctional()
        self.f_add = nnq.FloatFunctional()
        self.q = torch_q.QuantStub()
        self.dq = torch_q.DeQuantStub()
        self.act_hs = nn.Hardsigmoid()
        self.act_r = nn.ReLU6()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        x1 = self.f_add.add_scalar(input, 3.0)
        x2 = self.act_r(x1)
        x3 = self.q(self.dq(x2))
        return self.f_mul.mul_scalar(x3, 1 / 6)
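
# Usage sketch (illustrative, not part of the original module): QHardsigmoid computes
# relu6(x + 3) * (1 / 6), i.e. the hardsigmoid definition, using quantization-friendly ops;
# in float mode it should match nn.Hardsigmoid. The input shape is an assumption for
# illustration only.
#
#   x = torch.randn(2, 16)
#   assert torch.allclose(nn.Hardsigmoid()(x), QHardsigmoid(nn.Hardsigmoid())(x), atol=1e-6)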


class QLayerNorm(nn.Module):
    def __init__(self, layernorm: nn.LayerNorm) -> None:
        super().__init__()

        self.mean_dims = tuple(range(-len(layernorm.normalized_shape), 0))
        self.weight = torch.nn.Parameter(layernorm.weight.data.detach().clone())
        self.bias = torch.nn.Parameter(layernorm.bias.data.detach().clone())
        self.eps = layernorm.eps

        self.q_rsqrt = torch_q.QuantStub()
        self.q_weight = torch_q.QuantStub()
        self.q_bias = torch_q.QuantStub()
        self.dq_rsqrt = torch_q.DeQuantStub()

        self.f_neg = nnq.FloatFunctional()
        self.f_add_0 = nnq.FloatFunctional()
        self.f_mul_0 = nnq.FloatFunctional()
        self.f_add_1 = nnq.FloatFunctional()
        self.f_mul_1 = nnq.FloatFunctional()
        self.f_mul_2 = nnq.FloatFunctional()
        self.f_add_2 = nnq.FloatFunctional()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # LayerNorm(input) = (input - mean(input))
        #                    * rsqrt(mean((input - mean(input)) ** 2) + eps) * alpha + beta
        # Currently, LayerNorm is fully decomposed and the quantization parameters of each
        # intermediate activation are collected independently, which may reduce
        # quantization accuracy.
        mean = input.mean(self.mean_dims, keepdim=True)
        diff = self.f_add_0.add(input, self.f_neg.mul_scalar(mean, -1.0).expand_as(input))
        squared_difference = self.f_mul_0.mul(diff, diff)
        var = squared_difference.mean(self.mean_dims, keepdim=True)
        var_eps = self.f_add_1.add_scalar(var, self.eps)
        fdq_var_eps = self.dq_rsqrt(var_eps)
        std_inverse = torch.rsqrt(fdq_var_eps)
        q_std_inverse = self.q_rsqrt(std_inverse)
        weight_fq = self.q_weight(self.weight)
        bias_fq = self.q_bias(self.bias)
        norm = self.f_mul_1.mul(diff, q_std_inverse)
        weight_fq_expand = weight_fq.expand_as(norm)
        norm_alpha = self.f_mul_2.mul(norm, weight_fq_expand)
        bias_fq_expand = bias_fq.expand_as(norm_alpha)
        return self.f_add_2.add(norm_alpha, bias_fq_expand)
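
# Usage sketch (illustrative, not part of the original module): QLayerNorm rebuilds the
# normalization from mean/variance with quantizable primitives; in float mode it should
# agree with nn.LayerNorm up to floating-point error. The normalized shape and input
# size below are assumptions for illustration only.
#
#   ln = nn.LayerNorm(16)
#   x = torch.randn(4, 8, 16)
#   assert torch.allclose(ln(x), QLayerNorm(ln)(x), atol=1e-5)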


class QRMSNorm(nn.Module):
    def __init__(self, rmsnorm: 'nn.RMSNorm') -> None:
        super().__init__()

        self.mean_dims = tuple(range(-len(rmsnorm.normalized_shape), 0))
        self.weight = torch.nn.Parameter(rmsnorm.weight.data.detach().clone())
        self.eps = rmsnorm.eps

        self.q_rsqrt = torch_q.QuantStub()
        self.q_weight = torch_q.QuantStub()
        self.dq_rsqrt = torch_q.DeQuantStub()

        self.f_add_0 = nnq.FloatFunctional()
        self.f_mul_0 = nnq.FloatFunctional()
        self.f_mul_1 = nnq.FloatFunctional()
        self.f_mul_2 = nnq.FloatFunctional()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # RMSNorm(input) = input * rsqrt(mean(input ** 2) + eps) * alpha
        squared_input = self.f_mul_0.mul(input, input)
        if self.eps is None:
            rms_pre = squared_input.mean(self.mean_dims, keepdim=True)
        else:
            rms_pre = self.f_add_0.add_scalar(
                squared_input.mean(self.mean_dims, keepdim=True),
                self.eps,
            )
        fdq_rms_pre = self.dq_rsqrt(rms_pre)
        rms_inverse = torch.rsqrt(fdq_rms_pre)
        q_rms = self.q_rsqrt(rms_inverse)
        weight_fq = self.q_weight(self.weight)
        norm = self.f_mul_1.mul(input, q_rms)
        weight_fq_expand = weight_fq.expand_as(norm)
        return self.f_mul_2.mul(norm, weight_fq_expand)
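
# Usage sketch (illustrative, not part of the original module): QRMSNorm computes
# x * rsqrt(mean(x ** 2) + eps) * weight with quantizable primitives; in float mode it
# should agree with nn.RMSNorm (available in newer PyTorch releases) up to floating-point
# error. The normalized shape, eps, and input size below are assumptions for illustration.
#
#   rms = nn.RMSNorm(16, eps=1e-6)
#   x = torch.randn(4, 8, 16)
#   assert torch.allclose(rms(x), QRMSNorm(rms)(x), atol=1e-5)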