in distilvit/curate.py [0:0]
def load_model_and_tokenizer(self):
    """Load the causal LM and its tokenizer for ``self.model_name``.

    The loaded objects are stored on ``self.model`` and ``self.tokenizer``;
    nothing is returned.

    Loading strategy depends on the host platform:

    * macOS (``platform.system() == "Darwin"``): the model is loaded
      unquantized in ``torch.bfloat16`` with ``low_cpu_mem_usage=True``.
    * any other platform: the model is loaded with a 4-bit NF4
      ``BitsAndBytesConfig`` (double quantization, float16 compute).

    In both cases ``device_map="auto"`` and ``trust_remote_code=True`` are
    used. The tokenizer is loaded with ``trust_remote_code=True`` as well,
    so that a checkpoint shipping a custom tokenizer class loads the same
    way its model does.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Shared by both platform branches.
    load_kwargs = {"trust_remote_code": True}

    if platform.system() == "Darwin":
        # NOTE(review): presumably bitsandbytes quantization is not usable on
        # macOS, hence the plain bfloat16 load -- confirm.
        load_kwargs["torch_dtype"] = torch.bfloat16
        load_kwargs["low_cpu_mem_usage"] = True
        bnb_config = None
    else:
        # Only imported on the quantized path.
        from transformers import BitsAndBytesConfig

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

    self.model = AutoModelForCausalLM.from_pretrained(
        self.model_name,
        device_map="auto",
        quantization_config=bnb_config,
        **load_kwargs,
    )
    # The model above already trusts the repository's code; without the same
    # flag here, a repo with a custom tokenizer class would fail (or prompt
    # interactively) on this call.
    self.tokenizer = AutoTokenizer.from_pretrained(
        self.model_name, trust_remote_code=True
    )