def get_model_memory_footprint()

in assets/96_hf_bitsandbytes_integration/example.py [0:0]


def get_model_memory_footprint(model, *, include_buffers=False):
    r"""
    Return the number of bytes occupied by a model's tensors.

    Partially copied and inspired from: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2

    Args:
        model: A ``torch.nn.Module`` (any object exposing ``parameters()``
            and, when ``include_buffers`` is set, ``buffers()``).
        include_buffers: When True, also count the module's registered
            buffers (e.g. BatchNorm running statistics). Defaults to False,
            which preserves the original parameters-only measurement.

    Returns:
        int: Sum of ``nelement() * element_size()`` over every counted tensor,
        i.e. the footprint in bytes (0 for a module with no tensors).
    """
    # Generator expressions: no intermediate list is materialized just to sum it.
    footprint = sum(p.nelement() * p.element_size() for p in model.parameters())
    if include_buffers:
        footprint += sum(b.nelement() * b.element_size() for b in model.buffers())
    return footprint