def trainer()

in sdk/python/foundation-models/healthcare-ai/medimageinsight/classification_demo/adaptor_training/training.py [0:0]


def _train_one_epoch(model, train_ds, loss_function_ts, optimizer, device, epoch):
    """Run one optimisation pass over ``train_ds`` and return the mean batch loss.

    Returns NaN (instead of raising ZeroDivisionError) if ``train_ds`` yields
    no batches.
    """
    model.train()
    num_batches = len(train_ds)
    running_loss = 0.0
    step = 0  # stays 0 if the loader is empty

    for step, (features, pathology_label, _img_name) in tqdm(
        enumerate(train_ds, start=1),
        total=num_batches,
        desc=f"Train Epoch={epoch}",
        ncols=80,
        leave=False,
    ):
        features = features.to(device)
        pathology_label = pathology_label.to(device)

        optimizer.zero_grad()
        # The model returns a 2-tuple; only the second element (the class
        # predictions) feeds the loss.
        _, pred_pathology = model(features)

        loss = loss_function_ts(pred_pathology, pathology_label)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()  # single host sync per batch
        running_loss += batch_loss
        print(f"{step}/{num_batches}, train_loss: {batch_loss:.4f}")

    return running_loss / step if step else float("nan")


def _evaluate(model, test_ds, device, epoch):
    """Return ``(roc_auc, accuracy)`` of ``model`` over every batch of ``test_ds``."""
    model.eval()
    y_pred_list = []
    y_true_list = []

    with torch.no_grad():
        for features, pathology_label, _img_name in tqdm(
            test_ds,
            total=len(test_ds),
            desc=f"Test Epoch={epoch}",
            ncols=80,
            leave=False,
        ):
            features = features.to(device)
            pathology_label = pathology_label.to(device)

            _, pred_pathology = model(features)

            y_pred_list.append(pred_pathology)
            y_true_list.append(pathology_label)

        # Concatenate predictions and true labels across all batches.
        y_pred = torch.cat(y_pred_list, dim=0)
        y_true = torch.cat(y_true_list, dim=0)

    # Predictions are treated as unnormalised scores: softmax over dim 1
    # turns each row into per-class probabilities.
    y_scores = torch.softmax(y_pred, dim=1).cpu().numpy()
    y_true_np = y_true.cpu().numpy()

    if y_scores.ndim == 2 and y_scores.shape[1] == 2:
        # Binary task: sklearn's roc_auc_score requires a 1-D score for the
        # class with the greater label and raises ValueError on an (n, 2)
        # matrix, so pass only the positive-class probability.
        auc = roc_auc_score(y_true_np, y_scores[:, 1])
    else:
        auc = roc_auc_score(y_true_np, y_scores, multi_class="ovr")

    # Labels are compared against argmax, so they are expected to be
    # integer class indices.
    acc = (y_pred.argmax(dim=1) == y_true).sum().item() / len(y_true)
    return auc, acc


def trainer(train_ds, test_ds, model, loss_function_ts, optimizer, epochs, root_dir):
    """
    Train a classification model and keep the checkpoint with the best validation ROC AUC.

    Each epoch runs one optimisation pass over ``train_ds`` and then evaluates
    on ``test_ds``. Whenever the validation ROC AUC improves, the model's
    ``state_dict`` is written to ``<root_dir>/best_metric_model.pth``.

    Args:
        train_ds: Sized iterable of ``(features, label, image_name)`` batches
            used for training.
        test_ds: Sized iterable with the same batch layout, used for validation.
        model: Module whose forward returns a 2-tuple; the second element is
            used as the class predictions.
        loss_function_ts: Callable ``(predictions, labels) -> loss tensor``.
        optimizer: Object providing ``zero_grad()`` and ``step()`` (e.g. a
            ``torch.optim`` optimizer over ``model``'s parameters).
        epochs: Number of epochs to run.
        root_dir: Directory for the best checkpoint (it is not created here).

    Returns:
        ``(best_acc, best_metric)``: the validation accuracy at the epoch with
        the best ROC AUC, and that best ROC AUC. Both are -1 if ``epochs`` is 0.

    Note:
        ``model`` is not reloaded from the best checkpoint. On return it holds
        the final epoch's weights and is in eval mode.
    """
    start_time = time.perf_counter()

    best_metric = -1
    best_acc = -1
    best_metric_epoch = -1

    # Use the first GPU when available, otherwise the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    best_model_path = os.path.join(root_dir, "best_metric_model.pth")

    for epoch in range(epochs):
        print("-" * 10)
        print(f"Epoch {epoch + 1}/{epochs}")

        epoch_loss = _train_one_epoch(
            model, train_ds, loss_function_ts, optimizer, device, epoch
        )
        print(f"Epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        auc, acc_metric = _evaluate(model, test_ds, device, epoch)

        # Model selection is driven by ROC AUC; accuracy is only recorded
        # for the selected epoch.
        if auc > best_metric:
            best_metric = auc
            best_acc = acc_metric
            best_metric_epoch = epoch + 1
            torch.save(model.state_dict(), best_model_path)
            print("Saved new best metric model")

        print(
            f"Current epoch: {epoch + 1} Current AUC: {auc:.4f}"
            f" Current accuracy: {acc_metric:.4f}"
            f" Best AUC: {best_metric:.4f}"
            f" Best accuracy: {best_acc:.4f}"
            f" at epoch: {best_metric_epoch}"
        )

    training_time = time.perf_counter() - start_time
    hours, rem = divmod(training_time, 3600)
    minutes, seconds = divmod(rem, 60)
    print(f"Total Training Time: {int(hours):02}:{int(minutes):02}:{seconds:.2f}")
    print(
        f"Training completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    return best_acc, best_metric