in quant/utils/moving_average.py [0:0]
def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
    """Update (in training mode) and return the exponential moving average.

    In training mode the ``moving_average`` buffer is updated in place with
    autograd disabled:

    * if no batch has been tracked yet, ``x`` is copied into the buffer;
    * otherwise the buffer becomes
      ``momentum * moving_average + (1 - momentum) * x``.

    The ``num_batches_tracked`` counter is then incremented by one.
    Outside training mode nothing is modified.

    Args:
        x: New observation. Must be broadcastable to the shape of
            ``moving_average`` (presumably a 1-D tensor of the same
            length -- confirm against callers).

    Returns:
        The ``moving_average`` buffer object itself, not a copy.
    """
    if not self.training:
        # Eval mode: report the stored statistic without touching state.
        return self.moving_average  # type: ignore

    with torch.no_grad():
        # `.item()` reads the counter back as a Python number (a host sync
        # when the buffer lives on an accelerator).
        has_history = self.num_batches_tracked.item() > 0  # type: ignore
        if not has_history:
            # First tracked batch: seed the average with the observation.
            self.moving_average.copy_(x)  # type: ignore
        else:
            decay = self.momentum  # type: ignore
            retained = decay * self.moving_average  # type: ignore
            incoming = (torch.ones_like(decay) - decay) * x
            self.moving_average.copy_(retained + incoming)  # type: ignore
        self.num_batches_tracked += 1  # type: ignore

    return self.moving_average  # type: ignore