def export_sklearn_head_to_onnx()

in src/setfit/exporters/onnx.py [0:0]


def export_sklearn_head_to_onnx(model_head: LogisticRegression, opset: int) -> onnx.onnx_ml_pb2.ModelProto:
    """Convert the Scikit-Learn head from a SetFitModel to ONNX format.

    The ONNX input type is inferred from the dtype of the head's ``coef_`` array. For a
    meta-estimator without its own ``coef_``, the ``coef_`` of the first estimator in
    ``estimators_`` is used instead.

    Args:
        model_head (`LogisticRegression`): The trained SetFit model_head. Any fitted
            scikit-learn estimator exposing ``coef_`` is accepted. So is a meta-estimator
            whose ``estimators_`` all expose ``coef_``.
        opset (`int`): The ONNX opset passed as ``target_opset`` to ``convert_sklearn``.
            The opset is not guaranteed and will default to the maximum version possible
            for the sklearn model.

    Returns:
        [`onnx.onnx_ml_pb2.ModelProto`] The ONNX model generated from the sklearn head.

    Raises:
        ImportError: If `skl2onnx` is not installed an error will be raised asking
            to install this package.
        ValueError: If the input dtype cannot be inferred from the head. This happens when
            the head has neither ``coef_`` nor ``estimators_``, when ``estimators_`` is
            empty, or when not every estimator in ``estimators_`` has ``coef_``.
    """

    # Check if skl2onnx is installed
    try:
        import onnxconverter_common
        from skl2onnx import convert_sklearn
        from skl2onnx.common.data_types import guess_data_type
        from skl2onnx.sklapi import CastTransformer
        from sklearn.pipeline import Pipeline
    except ImportError as err:
        msg = """
        `skl2onnx` must be installed in order to convert a model with an sklearn head.
        Please install with `pip install skl2onnx`.
        """
        # Chain the original error so the missing module name stays visible in the traceback.
        raise ImportError(msg) from err

    # Determine which coefficient array decides the ONNX input dtype.
    input_shape = (None, model_head.n_features_in_)
    if hasattr(model_head, "coef_"):
        coef = model_head.coef_
    elif hasattr(model_head, "estimators_"):
        estimators = model_head.estimators_
        # `len` rather than truthiness: some sklearn estimators store estimators_ as an ndarray.
        if len(estimators) == 0:
            raise ValueError("The model_head is a meta-estimator but its estimators_ attribute is empty.")
        if not all(hasattr(e, "coef_") for e in estimators):
            raise ValueError(
                "The model_head is a meta-estimator but not all of the estimators have a coef_ attribute."
            )
        coef = estimators[0].coef_
    else:
        raise ValueError(
            "The model_head has neither a coef_ attribute nor an estimators_ attribute. Conversion to ONNX only "
            "supports heads with a coef_ attribute or meta-estimators whose estimators all have a coef_ attribute."
        )
    dtype = guess_data_type(coef, shape=input_shape)[0][1]
    # Set the shape explicitly on the returned type. The leading None leaves the batch
    # dimension unspecified; the second dimension is the head's feature count.
    dtype.shape = input_shape

    # If the datatype of the model is double we need to cast the outputs
    # from the setfit model to doubles for compatibility inside of ONNX.
    if isinstance(dtype, onnxconverter_common.data_types.DoubleTensorType):
        sklearn_model = Pipeline([("castdouble", CastTransformer(dtype=np.double)), ("head", model_head)])
    else:
        sklearn_model = model_head

    # Convert sklearn head into ONNX format.
    # NOTE: per the skl2onnx docs, zipmap=False keeps classifier probabilities as a plain
    # tensor rather than a list of dictionaries.
    onnx_model = convert_sklearn(
        sklearn_model,
        initial_types=[("model_head", dtype)],
        target_opset=opset,
        options={id(sklearn_model): {"zipmap": False}},
    )

    return onnx_model