def load_model()

in 2_slm-fine-tuning-mlstudio/phi/src_train/train_mlflow.py [0:0]


def load_model(args):
    """Load a causal language model and its tokenizer for training.

    Args:
        args: Object exposing ``model_name_or_path`` (handed to both
            ``from_pretrained`` calls) and ``max_seq_length`` (assigned to
            the tokenizer's ``model_max_length``).

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer pads on the right
        using its unk token.
    """
    checkpoint = args.model_name_or_path

    # The model is loaded before the tokenizer, in bfloat16, with the
    # KV cache disabled and no device map.
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        use_cache=False,
        trust_remote_code=True,
        # attn_implementation="flash_attention_2",  # opt in to flash-attention support
        torch_dtype=torch.bfloat16,
        device_map=None,
    )

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    tokenizer.model_max_length = args.max_seq_length

    # Pad with unk rather than eos to prevent endless generation.
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
    tokenizer.padding_side = "right"

    return model, tokenizer