def before_step()

in d2go/modeling/quantization.py [0:0]


    def before_step(self):
        """Advance the QAT schedule to match the trainer's current iteration.

        The schedule is a fixed sequence of one-shot transitions. Each is guarded
        by a flag in ``self._applied`` so it fires at most once:

        1. ``enable_fake_quant``: once ``iter >= START_ITER``, turn on fake
           quantization (stock and LQAT), then call
           ``_reset_qat_data_loader_if_needed``.
        2. ``enable_observer``: while ``ENABLE_OBSERVER_ITER <= iter <
           DISABLE_OBSERVER_ITER``, turn on the stock and LQAT static observers.
        3. ``enable_learnable_observer``: once ``iter >=
           ENABLE_LEARNABLE_OBSERVER_ITER``, turn on the LQAT learnable observers.
        4. ``disable_observer``: once ``iter >= DISABLE_OBSERVER_ITER``, turn off
           the stock, LQAT static and LQAT learnable observers.
        5. ``freeze_bn_stats``: once ``iter >= FREEZE_BN_ITER``, freeze BN stats.

        Transitions are checked in this order on every call. If the current
        iteration is already past several thresholds, all of them fire in the
        same call.

        Finally, once fake quant has been enabled and
        ``UPDATE_OBSERVER_STATS_PERIODICALLY`` is set, ``observer_update_stat``
        is applied on every iteration that is a multiple of
        ``UPDATE_OBSERVER_STATS_PERIOD``.
        """
        step = self.trainer.iter
        net = self.trainer.model
        qat_cfg = self.cfg.QUANTIZATION.QAT
        done = self._applied  # alias: updates below mutate self._applied

        def apply_all(*fns):
            # Call net.apply() with each function in turn, in the given order.
            for fn in fns:
                net.apply(fn)

        # 1. Start QAT by enabling fake quantization.
        if not done["enable_fake_quant"] and step >= qat_cfg.START_ITER:
            logger.info(f"[QAT] enable fake quant to start QAT, iter = {step}")
            apply_all(
                torch.ao.quantization.enable_fake_quant,
                qat_utils.enable_lqat_fake_quant,
            )
            done["enable_fake_quant"] = True
            # NOTE(review): helper is defined elsewhere in this module; presumably
            # it swaps in a QAT-specific data loader when configured to - confirm.
            _reset_qat_data_loader_if_needed(
                self.cfg, self.trainer, self.build_data_loader_func
            )

        # 2. Static observers are only switched on inside the observer window.
        if (
            not done["enable_observer"]
            and qat_cfg.ENABLE_OBSERVER_ITER <= step < qat_cfg.DISABLE_OBSERVER_ITER
        ):
            logger.info(f"[QAT] enable static observer, iter = {step}")
            apply_all(
                torch.ao.quantization.enable_observer,
                qat_utils.enable_lqat_static_observer,
            )
            done["enable_observer"] = True

        # 3. Learnable (LQAT) observers.
        if (
            not done["enable_learnable_observer"]
            and step >= qat_cfg.ENABLE_LEARNABLE_OBSERVER_ITER
        ):
            logger.info(f"[QAT] enabling learnable observer, iter = {step}")
            apply_all(qat_utils.enable_lqat_learnable_observer)
            done["enable_learnable_observer"] = True

        # 4. Turn every kind of observer off for the remaining iterations.
        if not done["disable_observer"] and step >= qat_cfg.DISABLE_OBSERVER_ITER:
            logger.info(f"[QAT] disabling observer for sub seq iters, iter = {step}")
            apply_all(
                torch.ao.quantization.disable_observer,
                qat_utils.disable_lqat_static_observer,
                qat_utils.disable_lqat_learnable_observer,
            )
            done["disable_observer"] = True

        # 5. Freeze batch-norm statistics.
        if not done["freeze_bn_stats"] and step >= qat_cfg.FREEZE_BN_ITER:
            logger.info(f"[QAT] freezing BN for subseq iters, iter = {step}")
            apply_all(torch.nn.intrinsic.qat.freeze_bn_stats)
            done["freeze_bn_stats"] = True

        # Periodic observer refresh, only after QAT has started.
        # NOTE(review): an UPDATE_OBSERVER_STATS_PERIOD of 0 would raise
        # ZeroDivisionError here when periodic updates are enabled.
        if (
            done["enable_fake_quant"]
            and qat_cfg.UPDATE_OBSERVER_STATS_PERIODICALLY
            and step % qat_cfg.UPDATE_OBSERVER_STATS_PERIOD == 0
        ):
            logger.info(f"[QAT] updating observers, iter = {step}")
            apply_all(observer_update_stat)