def post_training_quantize()

in d2go/modeling/quantization.py [0:0]


def post_training_quantize(cfg, model, data_loader):
    """Prepare a copy of ``model`` for post-training quantization and calibrate it.

    The model is deep-copied first, so the caller's instance is not modified.
    The copy is put in eval mode, its parameters are frozen, it is prepared
    for quantization, and then fed items from ``data_loader`` under
    ``torch.no_grad()`` to run calibration.

    Args:
        cfg: config object. This function reads
            ``QUANTIZATION.EAGER_MODE``,
            ``QUANTIZATION.PTQ.CALIBRATION_FORCE_ON_GPU``,
            ``QUANTIZATION.PTQ.CALIBRATION_NUM_IMAGES`` and, only when
            calibration was moved to GPU, ``MODEL.DEVICE``.
        model: the model to calibrate. If it defines ``prepare_for_quant(cfg)``
            that hook is used; otherwise ``default_prepare_for_quant`` is.
        data_loader: iterable of calibration inputs. Each item is passed to
            the model unchanged (``model(inputs)``), except for being moved
            to GPU when GPU calibration is forced.

    Returns:
        The prepared and calibrated copy of the model.
        NOTE(review): no convert step is performed in this function;
        presumably the caller converts the returned model -- confirm.
    """
    model = copy.deepcopy(model)
    model.eval()
    # TODO: check why some parameters will have gradient
    for param in model.parameters():
        param.requires_grad = False

    if hasattr(model, "prepare_for_quant"):
        model = model.prepare_for_quant(cfg)
    else:
        logger.info("Using default implementation for prepare_for_quant")
        model = default_prepare_for_quant(cfg, model)

    if cfg.QUANTIZATION.EAGER_MODE:
        torch.ao.quantization.prepare(model, inplace=True)
    # Lazy %-formatting: the (potentially large) model repr is only built when
    # INFO logging is actually enabled.
    logger.info("Prepared the PTQ model for calibration:\n%s", model)

    # Option for forcing running calibration on GPU, works only when the model supports
    # casting both model and inputs.
    calibration_force_on_gpu = (
        cfg.QUANTIZATION.PTQ.CALIBRATION_FORCE_ON_GPU and torch.cuda.is_available()
    )
    if calibration_force_on_gpu:
        # NOTE: model.to(device) may not handle cases such as normalizer, FPN, only
        # do move to GPU if specified.
        _cast_detection_model(model, "cuda")

    # Counted in data-loader items: one item == one calibration iteration.
    calibration_iters = cfg.QUANTIZATION.PTQ.CALIBRATION_NUM_IMAGES
    completed_iters = 0
    # Setting CALIBRATION_NUM_IMAGES to 0 allows skipping calibration; in that
    # case the data loader is not touched at all.
    if calibration_iters != 0:
        for inputs in data_loader:
            logger.info(
                "Running calibration iter: %s/%s", completed_iters, calibration_iters
            )

            if calibration_force_on_gpu:
                # Walk the input structure and substitute each tensor with its
                # CUDA copy; the rebuilt structure is read back from `.value`.
                iters = recursive_iterate(inputs)
                for x in iters:
                    if isinstance(x, torch.Tensor):
                        iters.send(x.to("cuda"))
                inputs = iters.value

            with torch.no_grad():
                model(inputs)

            completed_iters += 1
            # Stop right after the last required iteration so that no extra
            # item is fetched from the data loader just to be discarded.
            if completed_iters == calibration_iters:
                break

    # The loop only breaks when the requested count is reached, so a mismatch
    # means the data loader ran out first. A loader yielding exactly the
    # requested number of items does not trigger the warning.
    if completed_iters != calibration_iters:
        logger.warning("Can't run enough calibration iterations")

    # cast model back to the original device
    if calibration_force_on_gpu:
        _cast_detection_model(model, cfg.MODEL.DEVICE)

    return model