def transform_fn()

in autogluon/tabular-prediction/AutoGluon-Tabular-with-SageMaker/container-training/inference.py [0:0]


def transform_fn(models, data, input_content_type, output_content_type):
    """
    Transform a request using the Gluon model. Called once per request.

    Only 'text/csv' input is supported. The payload is parsed as headerless
    CSV, preprocessed with the stored column names, and passed to the
    predictor. If the label column is present in the request, prediction
    performance is also evaluated and printed; the predictions are returned
    either way.

    :param models: Pair of (predictor, column_dict); column_dict['columns']
        holds the column names passed to ``preprocess``.
    :param data: The request payload (CSV text; bytes are decoded as UTF-8).
    :param input_content_type: The request content type. ('text/csv')
    :param output_content_type: The (desired) response content type. ('text/csv')
    :return: (response payload, content type). The payload is the predictions
        as CSV without header or index or, if prediction fails even after
        filling NaNs with 0.0, the error message string.
    :raises NotImplementedError: if input_content_type is not 'text/csv'.
    """
    start = timer()
    net = models[0]
    column_dict = models[1]

    if input_content_type != 'text/csv':
        raise NotImplementedError("content_type must be 'text/csv'")

    # StringIO requires str; tolerate a bytes payload from the serving stack.
    if isinstance(data, (bytes, bytearray)):
        data = data.decode('utf-8')

    # Load dataset
    columns = column_dict['columns']
    df = pd.read_csv(StringIO(data), header=None)
    df_preprocessed = preprocess(df, columns, net.label_column)
    ds = task.Dataset(df=df_preprocessed)

    try:
        predictions = net.predict(ds)
    except Exception:
        # Retry once with missing values filled in, presumably because the
        # first failure was caused by NaNs in the request.
        try:
            predictions = net.predict(ds.fillna(0.0))
            warnings.warn('Filled NaN\'s with 0.0 in order to predict.')
        except Exception as e:
            # Return the error text: an Exception object is not a valid
            # response payload.
            return str(e), output_content_type

    # Print prediction counts. With many distinct values (e.g. regression),
    # show only the most frequent ones rather than an arbitrary subset.
    pred_counts = Counter(predictions.tolist())
    n_display_items = 30
    if len(pred_counts) > n_display_items:
        print(f'Top {n_display_items} prediction counts: '
              f'{dict(pred_counts.most_common(n_display_items))}')
    else:
        print(f'Prediction counts: {pred_counts}')

    # Form response: predictions as CSV, no header, no index.
    output = StringIO()
    pd.DataFrame(predictions).to_csv(output, header=False, index=False)
    response_body = output.getvalue()

    # If target column passed, evaluate predictions performance
    target = net.label_column
    if target in ds:
        print(f'Label column ({target}) found in input data. '
              'Therefore, evaluating prediction performance...')
        try:
            performance = net.evaluate_predictions(y_true=ds[target],
                                                   y_pred=predictions,
                                                   auxiliary_metrics=True)
            # DataFrames inside the metrics (if any) are serialized via to_json.
            print(json.dumps(performance, indent=4, default=pd.DataFrame.to_json))
            # NOTE(review): presumably gives the log output time to flush;
            # confirm it is still needed, since it adds latency per request.
            time.sleep(0.1)
        except Exception as e:
            # Print exceptions on evaluate, continue to return predictions
            print(f'Exception: {e}')

    elapsed_time = round(timer() - start, 3)
    print(f'Elapsed time: {elapsed_time} seconds')

    return response_body, output_content_type