assets/96_hf_bitsandbytes_integration/example.py (22 lines of code) (raw):
import torch
import torch.nn as nn
from bitsandbytes.nn import Linear8bitLt
# Utility function
def get_model_memory_footprint(model):
r"""
Partially copied and inspired from: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
"""
return sum([param.nelement() * param.element_size() for param in model.parameters()])
# Main script
fp16_model = nn.Sequential(
nn.Linear(64, 64),
nn.Linear(64, 64)
).to(torch.float16)
# Train and save your model!
torch.save(fp16_model.state_dict(), "model.pt")
# Define your int8 model!
int8_model = nn.Sequential(
Linear8bitLt(64, 64, has_fp16_weights=False),
Linear8bitLt(64, 64, has_fp16_weights=False)
)
int8_model.load_state_dict(torch.load("model.pt"))
int8_model = int8_model.to(0) # Quantization happens here
input_ = torch.randn(8, 64, dtype=torch.float16)
hidden_states = int8_model(input_.to(0))
mem_int8 = get_model_memory_footprint(int8_model)
mem_fp16 = get_model_memory_footprint(fp16_model)
print(f"Relative difference: {mem_fp16/mem_int8}")