in d2go/modeling/quantization.py [0:0]
def before_step(self):
    """Run the QAT schedule transitions that are due at the current iteration.

    The current iteration is read from ``self.trainer.iter`` and compared
    with the thresholds stored under ``self.cfg.QUANTIZATION.QAT``. Each
    one-shot stage records itself in ``self._applied`` so it is executed at
    most once:

    * ``enable_fake_quant`` -- from ``START_ITER`` on; afterwards
      ``_reset_qat_data_loader_if_needed`` is called with the config, the
      trainer and ``self.build_data_loader_func``.
    * ``enable_observer`` -- only while the iteration lies in
      ``[ENABLE_OBSERVER_ITER, DISABLE_OBSERVER_ITER)``.
    * ``enable_learnable_observer`` -- from
      ``ENABLE_LEARNABLE_OBSERVER_ITER`` on.
    * ``disable_observer`` -- from ``DISABLE_OBSERVER_ITER`` on; turns off
      the regular, LQAT static and LQAT learnable observers.
    * ``freeze_bn_stats`` -- from ``FREEZE_BN_ITER`` on.

    Independently of those flags, once fake quant has been enabled and
    ``UPDATE_OBSERVER_STATS_PERIODICALLY`` is set, ``observer_update_stat``
    is applied to the model on every iteration that is a multiple of
    ``UPDATE_OBSERVER_STATS_PERIOD``.
    """
    cur_iter = self.trainer.iter
    net = self.trainer.model
    qat_cfg = self.cfg.QUANTIZATION.QAT
    applied = self._applied

    # Stage: turn fake quantization on.
    if not applied["enable_fake_quant"] and cur_iter >= qat_cfg.START_ITER:
        logger.info(
            "[QAT] enable fake quant to start QAT, iter = {}".format(cur_iter)
        )
        net.apply(torch.ao.quantization.enable_fake_quant)
        net.apply(qat_utils.enable_lqat_fake_quant)
        applied["enable_fake_quant"] = True
        # The helper decides on its own whether a reset is actually needed.
        _reset_qat_data_loader_if_needed(
            self.cfg, self.trainer, self.build_data_loader_func
        )

    # Stage: turn static observers on, but only inside the observer window.
    if (
        not applied["enable_observer"]
        and qat_cfg.ENABLE_OBSERVER_ITER <= cur_iter < qat_cfg.DISABLE_OBSERVER_ITER
    ):
        logger.info("[QAT] enable static observer, iter = {}".format(cur_iter))
        net.apply(torch.ao.quantization.enable_observer)
        net.apply(qat_utils.enable_lqat_static_observer)
        applied["enable_observer"] = True

    # Stage: turn learnable observers on.
    if (
        not applied["enable_learnable_observer"]
        and cur_iter >= qat_cfg.ENABLE_LEARNABLE_OBSERVER_ITER
    ):
        logger.info(f"[QAT] enabling learnable observer, iter = {cur_iter}")
        net.apply(qat_utils.enable_lqat_learnable_observer)
        applied["enable_learnable_observer"] = True

    # Stage: turn every kind of observer off for the remaining iterations.
    if (
        not applied["disable_observer"]
        and cur_iter >= qat_cfg.DISABLE_OBSERVER_ITER
    ):
        logger.info(
            "[QAT] disabling observer for sub seq iters, iter = {}".format(cur_iter)
        )
        net.apply(torch.ao.quantization.disable_observer)
        net.apply(qat_utils.disable_lqat_static_observer)
        net.apply(qat_utils.disable_lqat_learnable_observer)
        applied["disable_observer"] = True

    # Stage: freeze batch-norm statistics.
    if not applied["freeze_bn_stats"] and cur_iter >= qat_cfg.FREEZE_BN_ITER:
        logger.info(
            "[QAT] freezing BN for subseq iters, iter = {}".format(cur_iter)
        )
        net.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
        applied["freeze_bn_stats"] = True

    # Recurring (not one-shot): refresh observer statistics on every
    # iteration that is a multiple of the configured period.
    if (
        applied["enable_fake_quant"]
        and qat_cfg.UPDATE_OBSERVER_STATS_PERIODICALLY
        and cur_iter % qat_cfg.UPDATE_OBSERVER_STATS_PERIOD == 0
    ):
        logger.info(f"[QAT] updating observers, iter = {cur_iter}")
        net.apply(observer_update_stat)