in quant/binary/activation_quantization.py [0:0]
def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
    """Forward pass of activation quantization."""
    if self.training:
        # batch_vs is a 2D tensor that stores each v_i along each row
        batch_vs, x_q = self._batch_quantization(x)
        if self.moving_average_mode != MovingAverageMode.off:
            vs_batch_avg = batch_vs.mean(1)
            # Calling moving_avg_module updates its internal statistics under the hood,
            # similar to the forward pass of batch norm.
            moving_avg_vs = self.moving_avg_module(vs_batch_avg)
            if self.moving_average_mode == MovingAverageMode.train_and_eval:
                # To quantize with the moving-average scalars, expand each scaling
                # factor from a single mean element to the batch size.
                vs = [
                    moving_avg_vs[i].expand(x.shape[0])
                    for i in range(self.num_scaling_factors)
                ]
                x_q = self._moving_average_quantization(x, vs)
    else:
        if self.moving_average_mode != MovingAverageMode.off:
            # To quantize with the moving-average scalars, expand each scaling
            # factor from a single mean element to the batch size.
            vs = [
                self.moving_avg_module.moving_average[i].expand(x.shape[0])  # type: ignore
                for i in range(self.moving_avg_module.moving_average.size(0))  # type: ignore
            ]
            x_q = self._moving_average_quantization(x, vs)
        else:
            # Moving average is off: quantize with per-batch scaling factors at eval time too.
            batch_vs, x_q = self._batch_quantization(x)
    return x_q
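
# A minimal, self-contained sketch (illustration only, not part of this module) of the
# expand pattern used in forward() above: each moving-average scaling factor is a single
# scalar, and .expand() broadcasts it across the batch dimension so it can be applied to
# every example in the batch. The tensor values and batch size below are hypothetical.
import torch

_example_moving_avg_vs = torch.tensor([0.7, 0.3])  # one scalar per scaling factor
_example_batch_size = 4
_example_vs = [
    _example_moving_avg_vs[i].expand(_example_batch_size)
    for i in range(_example_moving_avg_vs.size(0))
]
# _example_vs[0] -> tensor([0.7000, 0.7000, 0.7000, 0.7000])
# _example_vs[1] -> tensor([0.3000, 0.3000, 0.3000, 0.3000])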