def predict()

in src/sagemaker_xgboost_container/algorithm_mode/serve_utils.py [0:0]


def predict(model, model_format, dtest, input_content_type, objective=None):
    """Run inference with a single booster or with an ensemble of boosters.

    When the (first) model is in ``PKL_FORMAT``, the feature count of ``dtest`` is
    validated against the feature count reported by that model before predicting.

    :param model: a single booster, or a list of boosters forming an ensemble.
    :param model_format: format of ``model``; when ``model`` is a list this is indexed
        in parallel and only the first entry is inspected.
    :param dtest: inference data exposing ``feature_names`` and ``num_col()``
        (presumably an ``xgboost.DMatrix`` - confirm against callers).
    :param input_content_type: content type of the request, resolved through
        ``get_content_type``.
    :param objective: training objective; for an ensemble it selects between
        majority vote (``MULTI_SOFTMAX`` / ``BINARY_HINGE``) and averaging.
    :return: the output of ``predict`` for a single booster; for an ensemble, the
        per-position vote or mean across the boosters' predictions.
    :raises ValueError: if the content type cannot be resolved or is not one of
        ``LIBSVM`` / ``CSV`` / ``RECORDIO_PROTOBUF``, or if the feature counts of the
        data and the model are incompatible.
    """
    # Decide once whether we are serving an ensemble so that the validation step and
    # the prediction step can never disagree about it.
    is_ensemble = isinstance(model, list)
    bst, bst_format = (model[0], model_format[0]) if is_ensemble else (model, model_format)

    if bst_format == PKL_FORMAT:
        # Only the first booster is checked, even when an ensemble was supplied.
        x = bst.num_features()
        y = len(dtest.feature_names) if dtest.feature_names is not None else dtest.num_col()

        try:
            content_type = get_content_type(input_content_type)
        except Exception as err:
            # Keep the original failure attached as __cause__ for debugging.
            raise ValueError("Content type {} is not supported".format(input_content_type)) from err

        if content_type == LIBSVM:
            # libsvm input is rejected only when it exceeds the model by more than one column.
            if y > x + 1:
                raise ValueError(
                    "Feature size of libsvm inference data {} is larger than "
                    "feature size of trained model {}.".format(y, x)
                )
        elif content_type in [CSV, RECORDIO_PROTOBUF]:
            # Dense input must match the model exactly or be exactly one column short.
            if not ((x == y) or (x == y + 1)):
                raise ValueError(
                    "Feature size of {} inference data {} is not consistent "
                    "with feature size of trained model {}.".format(content_type, y, x)
                )
        else:
            raise ValueError("Content type {} is not supported".format(content_type))

    if not is_ensemble:
        return model.predict(dtest, ntree_limit=getattr(model, "best_ntree_limit", 0), validate_features=False)

    ensemble = [
        booster.predict(dtest, ntree_limit=getattr(booster, "best_ntree_limit", 0), validate_features=False)
        for booster in model
    ]

    if objective in [MULTI_SOFTMAX, BINARY_HINGE]:
        logging.info(f"Vote ensemble prediction of {objective} with {len(model)} models")
        votes = np.asarray(ensemble)
        winners = np.asarray(stats.mode(votes, axis=0).mode)
        # Depending on the SciPy version the reduced axis is either kept (length 1) or
        # dropped; reshaping to the per-model prediction shape yields the same result
        # in both cases, whereas indexing with [0] would return a single element when
        # the axis has already been dropped.
        return winners.reshape(votes.shape[1:])

    logging.info(f"Average ensemble prediction of {objective} with {len(model)} models")
    return np.mean(ensemble, axis=0)