in florence2-VQA/src_train/train_mlflow.py [0:0]
def load_model(model_name_or_path="microsoft/Florence-2-base-ft", freeze_vision_encoder=True):
    """Load a Florence-2 model and its processor into the module-level globals.

    Populates the module-level ``model`` and ``processor`` globals from
    ``model_name_or_path``, both pinned to the ``refs/pr/6`` hub revision and
    loaded with ``trust_remote_code=True``.

    Args:
        model_name_or_path: Hugging Face hub id or local path of the checkpoint.
        freeze_vision_encoder: If True, disable gradients for every parameter
            of ``model.vision_tower`` so that only the remaining weights are
            updated during training.

    Returns:
        The ``(model, processor)`` pair that was also stored in the globals.
    """
    global model
    global processor
    # Model and processor are loaded from the same pinned revision.
    revision = "refs/pr/6"
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
        revision=revision,
        # `device` is not defined in this function; presumably a module-level
        # global set elsewhere in this file — verify.
        device_map=device,
    )
    processor = AutoProcessor.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
        revision=revision,
    )
    if freeze_vision_encoder:
        # `requires_grad` is the flag autograd honors. The previous
        # `param.is_trainable = False` only set an unused Python attribute,
        # so the vision encoder silently stayed trainable.
        for param in model.vision_tower.parameters():
            param.requires_grad = False
    return model, processor