def load_model()

in florence2-VQA/src_train/train_mlflow.py [0:0]


def load_model(model_name_or_path="microsoft/Florence-2-base-ft", freeze_vision_encoder=True,
               revision="refs/pr/6"):
    """Load a Florence-2 model and processor into the module-level globals.

    The results are stored in the module globals ``model`` and
    ``processor`` (nothing is returned), so other code in this file can
    use them after this call.

    Args:
        model_name_or_path: Hugging Face Hub id or local path passed to
            ``from_pretrained`` for both the model and the processor.
        freeze_vision_encoder: When True, turn off gradients for every
            parameter of ``model.vision_tower`` so those weights receive
            no gradients during training.
        revision: Hub revision (branch, tag, PR ref or commit hash) used
            for both the model and the processor. Defaults to
            ``"refs/pr/6"``, the revision this script was pinned to.

    Note:
        ``device_map`` is taken from the module-level ``device``, which is
        defined elsewhere in this file.
    """
    global model
    global processor

    # Model and processor are loaded from the same revision; trust_remote_code
    # lets transformers run the repository's custom modeling/processing code.
    common_kwargs = dict(trust_remote_code=True, revision=revision)

    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path, device_map=device, **common_kwargs
    )
    processor = AutoProcessor.from_pretrained(model_name_or_path, **common_kwargs)

    if freeze_vision_encoder:
        # `requires_grad` is the flag autograd actually honours. Setting a
        # custom attribute such as `is_trainable` has no effect on training.
        model.vision_tower.requires_grad_(False)