utils/check_inference_input_params.py (72 lines of code) (raw):

# coding=utf-8
# Copyright 2024-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility script to check consistency between input parameters of InferenceClient methods and generated types.

TODO: check all methods
TODO: check parameters types
TODO: check parameters default values
TODO: check parameters (type, description) are consistent in the docstrings
TODO: (low priority) automatically generate the input types from the methods
"""

import inspect
import sys
from dataclasses import is_dataclass
from typing import Any, List, get_args

from huggingface_hub import InferenceClient
from huggingface_hub.inference._generated import types


METHODS_TO_SKIP = [
    # + all private methods
    "post",
    "conversational",
]
# Only methods listed here are checked (see TODO above). Values are the parameter names to ignore.
PARAMETERS_TO_SKIP = {
    "chat_completion": set(),
    "text_generation": {
        "stop",  # stop_sequence instead (for legacy reasons)
    },
}
UNDOCUMENTED_PARAMETERS = {
    "self",
}


def _unwrap_optional(annotation: Any) -> Any:
    """Return the first non-`None` type argument of `annotation`, or `annotation` itself if it has none.

    `Optional[X]` (i.e. `Union[X, None]`) yields `X`; a plain type `X` (no type arguments) yields `X`.
    """
    args = [arg for arg in get_args(annotation) if arg is not type(None)]
    return args[0] if args else annotation


def check_method(method_name: str, method: Any) -> List[str]:
    """Check that an `InferenceClient` method is consistent with its generated input type.

    The input type name is derived from the method name (e.g. `text_generation` -> `TextGenerationInput`).
    Every field of its `parameters` dataclass (or of the input type itself, for `chat_completion`) must
    appear in the method signature, and every method parameter must be documented in the docstring as
    `" <name> ("`.

    Args:
        method_name: Name of the method on `InferenceClient`.
        method: The (unbound) function object.

    Returns:
        A list of human-readable error messages. Empty if no inconsistency was found.
    """
    input_type_name = "".join(part.capitalize() for part in method_name.split("_")) + "Input"
    input_type = getattr(types, input_type_name, None)
    if input_type is None:
        return [f"Missing input type for method {method_name}"]

    # `__doc__` is None when the method has no docstring: treat it as empty so that every
    # parameter is reported as undocumented instead of crashing on `in None`.
    docstring = method.__doc__ or ""
    parameters_to_skip = PARAMETERS_TO_SKIP.get(method_name, set())

    if method_name == "chat_completion":
        # Special case for chat_completion
        parameters_type = input_type
    else:
        parameters_field = input_type.__dataclass_fields__.get("parameters", None)
        if parameters_field is None:
            return [f"Missing 'parameters' field for type {input_type}"]
        # The field is typically annotated `Optional[<ParametersType>]`; also accept a plain type.
        parameters_type = _unwrap_optional(parameters_field.type)
        if not is_dataclass(parameters_type):
            return [f"'parameters' field is not a dataclass for type {input_type} ({parameters_type})"]

    # For each expected parameter, check it is defined
    logs = []
    method_params = inspect.signature(method).parameters
    for param_name in parameters_type.__dataclass_fields__:
        if param_name in parameters_to_skip:
            continue
        if param_name not in method_params:
            logs.append(f"Missing parameter {param_name} in method signature")

    # Check parameter is documented in docstring
    for param_name in method_params:
        if param_name in UNDOCUMENTED_PARAMETERS or param_name in parameters_to_skip:
            continue
        if f" {param_name} (" not in docstring:
            logs.append(f"Parameter {param_name} is not documented")

    return logs


def main() -> int:
    """Inspect public `InferenceClient` methods individually and print a report.

    Returns:
        `0` if every checked method is consistent, `1` otherwise.
    """
    exit_code = 0
    all_logs = []  # print details only if errors are found
    for method_name, method in inspect.getmembers(InferenceClient, predicate=inspect.isfunction):
        if method_name.startswith("_") or method_name in METHODS_TO_SKIP:
            continue
        if method_name not in PARAMETERS_TO_SKIP:
            all_logs.append(f" ⏩️ {method_name}: skipped")
            continue

        logs = check_method(method_name, method)
        if logs:
            exit_code = 1
            all_logs.append(f" ❌ {method_name}: errors found")
            all_logs.append("\n".join(" " * 4 + log for log in logs))
        else:
            all_logs.append(f" ✅ {method_name}: success!")

    if exit_code == 0:
        print("✅ All good! (inference inputs)")
    else:
        print("❌ Inconsistency found in inference inputs.")
        for log in all_logs:
            print(log)
    return exit_code


if __name__ == "__main__":
    # `sys.exit` rather than the `site`-provided `exit()`, which is not guaranteed to exist.
    sys.exit(main())