def show_classification_report_confusion_matrix()

in autogluon/tabular-prediction/AutoGluon-Tabular-with-SageMaker/utils/ag_utils.py [0:0]


def _format_table_widget(df, float_format):
    """Format every cell of *df* with *float_format* and wrap it in an ``HTML`` widget."""
    df = df.applymap(float_format.format)
    styler = df.style.set_table_attributes('class="table"')
    # Styler.render() no longer exists in pandas 2.0; to_html() is its
    # replacement, so only fall back to render() on releases that lack it.
    df_html = styler.to_html() if hasattr(styler, 'to_html') else styler.render()
    return HTML(df_html)


def _load_image_widget(path):
    """Read the PNG file at *path* into an ``Image`` widget, closing the file afterwards."""
    with open(path, 'rb') as img_file:
        return Image(value=img_file.read(), format='png')


def show_classification_report_confusion_matrix(Job_Name):
    """Display the model-analysis artifacts stored under ``./tmp/<Job_Name>/``.

    Four files are looked up: ``classification_report.csv``,
    ``feature_importance.csv``, ``confusion_matrix.png`` and
    ``roc_auc_curve.png``. Each one that exists is turned into a widget; each
    one that is missing is replaced by an empty ``VBox`` placeholder. Tables
    go in the left column of a 1x2 grid, images in the right column.

    Nothing is displayed when none of the four files exists.

    Args:
        Job_Name: Name of the sub-directory of ``./tmp`` holding the files.

    Returns:
        None. The output is shown via ``show_in_html`` and ``display``.
    """
    classificationreport_fname = f'./tmp/{Job_Name}/classification_report.csv'
    featureimportance_fname = f'./tmp/{Job_Name}/feature_importance.csv'
    confusionmatrix_fname = f'./tmp/{Job_Name}/confusion_matrix.png'
    roc_auc_curve_fname = f'./tmp/{Job_Name}/roc_auc_curve.png'

    # Becomes True as soon as any one artifact is found on disk.
    has_data = False

    # Classification report
    if os.path.exists(classificationreport_fname):
        df = pd.read_csv(classificationreport_fname)
        # The label column is read back under pandas' placeholder name for an
        # unnamed header; give it a readable name and use it as the index.
        df = df.rename(columns={'Unnamed: 0': 'Label'})
        df.set_index('Label', inplace=True)
        cr_widget_html = _format_table_widget(df, '{0:.2f}')
        has_data = True
    else:
        cr_widget_html = VBox([])

    # Feature importance
    if os.path.exists(featureimportance_fname):
        df = pd.read_csv(featureimportance_fname)
        df.set_index(df.columns[0], inplace=True)
        fi_widget_html = _format_table_widget(df, '{0:.3f}')
        has_data = True
    else:
        fi_widget_html = VBox([])

    cr_title_html = get_html_text('Classification report', '#ff9900')
    fi_title_html = get_html_text('Feature importance', '#ff9900')
    widget_tables = VBox([cr_title_html, cr_widget_html, fi_title_html, fi_widget_html])

    # Confusion matrix
    if os.path.exists(confusionmatrix_fname):
        widget_cm_img = _load_image_widget(confusionmatrix_fname)
        has_data = True
    else:
        widget_cm_img = VBox([])

    # ROC curve
    if os.path.exists(roc_auc_curve_fname):
        widget_roc_img = _load_image_widget(roc_auc_curve_fname)
        has_data = True
    else:
        widget_roc_img = VBox([])

    cm_title_html = get_html_text('Confusion matrix', '#ff9900')
    roc_title_html = get_html_text('ROC Curve', '#ff9900')
    widget_imgs = VBox(
        [cm_title_html, widget_cm_img, roc_title_html, widget_roc_img],
        layout=Layout(margin='0 0 0 10px'),
    )

    if has_data:
        show_in_html('Model analysis on test dataset', '#000099')

        grid = GridspecLayout(1, 2)
        grid[0, 0] = widget_tables
        grid[0, 1] = widget_imgs
        display(grid)