def predict()

in src/sagemaker_xgboost_container/algorithm_mode/serve_utils.py [0:0]


def predict(model, model_format, dtest, input_content_type, objective=None):
    """Run inference with a single booster or with an ensemble of boosters.

    When the (first) model is in ``PKL_FORMAT``, the number of feature names
    on ``dtest`` is checked against the number of feature names on that
    model before predicting. For an ensemble only the first member is
    checked.

    :param model: an object exposing ``predict(dtest, ntree_limit=...,
        validate_features=...)``, or a list of such objects (an ensemble)
    :param model_format: format tag of ``model``; when ``model`` is a list
        this is indexed in parallel and only element 0 is inspected
    :param dtest: inference data handed to each booster's ``predict``; must
        expose ``feature_names`` when the model is in ``PKL_FORMAT``
    :param input_content_type: content type of the request payload, resolved
        through ``get_content_type``
    :param objective: training objective; for an ensemble, ``MULTI_SOFTMAX``
        and ``BINARY_HINGE`` select majority voting, anything else averaging
    :return: the booster's prediction for a single model; for an ensemble,
        the element-wise vote or mean across the members' predictions
    :raises ValueError: if the content type is not supported, or the feature
        count of the inference data is inconsistent with the trained model
    """
    # One list check drives both the validation and the prediction path, so
    # list subclasses are treated as ensembles consistently.
    is_ensemble = isinstance(model, list)
    bst, bst_format = (model[0], model_format[0]) if is_ensemble else (model, model_format)

    if bst_format == PKL_FORMAT:
        x = len(bst.feature_names)
        y = len(dtest.feature_names)

        try:
            content_type = get_content_type(input_content_type)
        except Exception as e:
            # Keep the original failure attached as the cause for debugging.
            raise ValueError('Content type {} is not supported'.format(input_content_type)) from e

        if content_type == LIBSVM:
            # NOTE(review): one extra column is tolerated for libsvm input,
            # presumably an index-base offset; confirm against the parser.
            if y > x + 1:
                raise ValueError('Feature size of libsvm inference data {} is larger than '
                                 'feature size of trained model {}.'.format(y, x))
        elif content_type in [CSV, RECORDIO_PROTOBUF]:
            # The data may match the model exactly or have one feature fewer.
            if not ((x == y) or (x == y + 1)):
                raise ValueError('Feature size of {} inference data {} is not consistent '
                                 'with feature size of trained model {}.'.
                                 format(content_type, y, x))
        else:
            raise ValueError('Content type {} is not supported'.format(content_type))

    if not is_ensemble:
        return model.predict(dtest,
                             ntree_limit=getattr(model, "best_ntree_limit", 0),
                             validate_features=False)

    # Each booster uses its own best_ntree_limit when it has one, else 0.
    ensemble = [booster.predict(dtest,
                                ntree_limit=getattr(booster, "best_ntree_limit", 0),
                                validate_features=False) for booster in model]

    if objective in [MULTI_SOFTMAX, BINARY_HINGE]:
        logging.info(f"Vote ensemble prediction of {objective} with {len(model)} models")
        stacked = np.asarray(ensemble)
        votes = np.asarray(stats.mode(stacked, axis=0).mode)
        # Older SciPy keeps the reduced axis (shape (1, n)); SciPy >= 1.11
        # drops it (shape (n,)). Strip the leading axis only when present so
        # the per-sample vote array is returned on every SciPy version.
        return votes[0] if votes.ndim == stacked.ndim else votes

    logging.info(f"Average ensemble prediction of {objective} with {len(model)} models")
    return np.mean(ensemble, axis=0)