# train_model — from florence2-VQA/src_train/train_mlflow.py

def train_model(args, train_dataset, val_dataset):
    """Fine-tune the module-level `model` on a VQA dataset, logging to MLflow.

    Args:
        args: namespace with epochs, save_steps, grad_accum_steps,
            train_batch_size, eval_batch_size, warmup_ratio, seed,
            lr_scheduler_type, output_dir, model_dir.
        train_dataset: training dataset, batched via DataLoader + collate_fn.
        val_dataset: validation dataset; also sampled for periodic image logs.

    Side effects: writes rolling checkpoints under args.output_dir, the final
    model/processor under args.model_dir, copies the "dependencies" folder
    alongside the final model, and logs params/metrics/images to MLflow.
    Relies on module-level `model`, `processor`, `collate_fn`, `run_example`,
    and `create_image_with_text`.
    """
    epochs = args.epochs
    save_steps = args.save_steps
    grad_accum_steps = args.grad_accum_steps

    train_batch_size = args.train_batch_size
    eval_batch_size = args.eval_batch_size
    num_workers = 0

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=train_batch_size, collate_fn=partial(collate_fn, processor=processor), num_workers=num_workers, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=eval_batch_size, collate_fn=partial(collate_fn, processor=processor), num_workers=num_workers)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6)

    # The scheduler is stepped once per *optimizer update*, not per batch, so
    # size the schedule in updates (ceil counts the trailing partial bucket).
    # The original per-batch count stretched warmup/decay by grad_accum_steps×.
    updates_per_epoch = (len(train_loader) + grad_accum_steps - 1) // grad_accum_steps
    num_training_steps = epochs * updates_per_epoch
    num_warmup_steps = int(num_training_steps * args.warmup_ratio)

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )

    saved_models = []
    global_step = 0  # optimizer updates across all epochs; used as MLflow step

    with mlflow.start_run():

        mlflow.log_params({
            "epochs": epochs,
            "train_batch_size": args.train_batch_size,
            "eval_batch_size": args.eval_batch_size,
            "seed": args.seed,
            "lr_scheduler_type": args.lr_scheduler_type,
            "grad_accum_steps": grad_accum_steps,
            "num_training_steps": num_training_steps,
            "num_warmup_steps": num_warmup_steps,
        })

        for epoch in range(epochs):
            # Re-enter train mode each epoch: validation below switches the
            # model to eval(), and a one-time train() call before the loop
            # would leave every epoch after the first training in eval mode.
            model.train()
            train_loss = 0.0
            optimizer.zero_grad()

            for step, (inputs, answers) in enumerate(tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{epochs}")):

                # .to(device) is idempotent if the collate fn already moved
                # the tensors; labels were moved but inputs were not.
                input_ids = inputs["input_ids"].to(device)
                pixel_values = inputs["pixel_values"].to(device)
                labels = processor.tokenizer(text=answers, return_tensors="pt", padding=True, return_token_type_ids=False).input_ids.to(device)

                outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels)
                # Scale before backward so accumulated gradients *average*
                # over the bucket instead of summing grad_accum_steps losses.
                loss = outputs.loss / grad_accum_steps
                loss.backward()
                train_loss += loss.item()  # scaled terms sum to the bucket average

                # Also flush on the last batch so a trailing partial bucket
                # is stepped instead of leaking gradients into the next epoch.
                is_last_batch = (step + 1) == len(train_loader)
                if (step + 1) % grad_accum_steps == 0 or is_last_batch:
                    global_step += 1
                    learning_rate = lr_scheduler.get_last_lr()[0]
                    progress = (step + 1) / len(train_loader)
                    print(f'Epoch [{epoch+1}/{epochs}], Step [{step+1}/{len(train_loader)}], Learning Rate: {learning_rate}, Loss: {train_loss}')
                    mlflow.log_metric("train_loss", train_loss, step=global_step)
                    mlflow.log_metric("learning_rate", learning_rate, step=global_step)
                    mlflow.log_metric("progress", progress, step=global_step)

                    optimizer.step()
                    optimizer.zero_grad()
                    lr_scheduler.step()
                    train_loss = 0.0

                if (step + 1) % save_steps == 0:
                    # Include the epoch in the path: step-only names collide
                    # across epochs, putting duplicate paths in saved_models so
                    # the rotation could delete the checkpoint just written.
                    output_dir = f"./{args.output_dir}/epoch_{epoch+1}_steps_{step+1}"
                    os.makedirs(output_dir, exist_ok=True)
                    model.save_pretrained(output_dir)
                    processor.save_pretrained(output_dir)
                    print(f'Model saved at step {step+1} of epoch {epoch+1}')
                    saved_models.append(output_dir)

                    # Log a sample prediction image for qualitative tracking
                    idx = random.randrange(len(val_dataset))
                    val_img = val_dataset[idx][-1]
                    result = run_example("DocVQA", 'What do you see in this image?', val_dataset[idx][-1])
                    val_img_result = create_image_with_text(val_img, json.dumps(result))
                    mlflow.log_image(val_img_result, key="DocVQA", step=global_step)

                    # Keep only the two most recent checkpoints on disk
                    if len(saved_models) > 2:
                        old_model = saved_models.pop(0)
                        if os.path.exists(old_model):
                            shutil.rmtree(old_model)
                            print(f'Removed old model: {old_model}')

            # Validation phase
            model.eval()
            val_loss = 0

            with torch.no_grad():
                for (inputs, answers) in tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}/{epochs}"):

                    input_ids = inputs["input_ids"].to(device)
                    pixel_values = inputs["pixel_values"].to(device)
                    labels = processor.tokenizer(text=answers, return_tensors="pt", padding=True, return_token_type_ids=False).input_ids.to(device)

                    outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels)
                    val_loss += outputs.loss.item()

            # max(..., 1) guards a degenerate empty validation loader
            avg_val_loss = val_loss / max(len(val_loader), 1)

            mlflow.log_metric("avg_val_loss", avg_val_loss, step=epoch + 1)
            print(f"Average Validation Loss: {avg_val_loss}")

        # Save the final model checkpoint and bundle dependencies next to it
        model_dir = args.model_dir
        os.makedirs(model_dir, exist_ok=True)
        model.save_pretrained(model_dir)
        processor.save_pretrained(model_dir)

        dependencies_dir = "dependencies"
        shutil.copytree(dependencies_dir, model_dir, dirs_exist_ok=True)