api_inference_community/validation.py
import json
import os
import subprocess
from base64 import b64decode
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union
import annotated_types
import numpy as np
from pydantic import BaseModel, RootModel, Strict, ValidationInfo, field_validator
from typing_extensions import Annotated
MinLength = Annotated[int, annotated_types.Ge(1), annotated_types.Le(500), Strict()]
MaxLength = Annotated[int, annotated_types.Ge(1), annotated_types.Le(500), Strict()]
TopK = Annotated[int, annotated_types.Ge(1), Strict()]
TopP = Annotated[float, annotated_types.Ge(0.0), annotated_types.Le(1.0), Strict()]
MaxTime = Annotated[float, annotated_types.Ge(0.0), annotated_types.Le(120.0), Strict()]
NumReturnSequences = Annotated[
int, annotated_types.Ge(1), annotated_types.Le(10), Strict()
]
RepetitionPenalty = Annotated[
float, annotated_types.Ge(0.0), annotated_types.Le(100.0), Strict()
]
Temperature = Annotated[
float, annotated_types.Ge(0.0), annotated_types.Le(100.0), Strict()
]
CandidateLabels = Annotated[list, annotated_types.MinLen(1)]
class FillMaskParamsCheck(BaseModel):
top_k: Optional[TopK] = None
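# Illustrative sketch, not part of the original module: because the Annotated
# aliases above carry Strict(), pydantic refuses to coerce strings to numbers.
#
#   FillMaskParamsCheck.model_validate({"top_k": 5})    # ok
#   FillMaskParamsCheck.model_validate({"top_k": "5"})  # ValidationError (strict int)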
class ZeroShotParamsCheck(BaseModel):
candidate_labels: Union[str, CandidateLabels]
multi_label: Optional[bool] = None
class SharedGenerationParams(BaseModel):
min_length: Optional[MinLength] = None
max_length: Optional[MaxLength] = None
top_k: Optional[TopK] = None
top_p: Optional[TopP] = None
max_time: Optional[MaxTime] = None
repetition_penalty: Optional[RepetitionPenalty] = None
temperature: Optional[Temperature] = None
@field_validator("max_length")
def max_length_must_be_larger_than_min_length(
cls, max_length: Optional[MaxLength], values
):
min_length = values.data.get("min_length", 0)
if min_length is None:
min_length = 0
if max_length is not None and max_length < min_length:
raise ValueError("min_length cannot be larger than max_length")
return max_length
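# Hedged example (assumes standard pydantic v2 semantics): the validator above
# reads previously validated fields, so inconsistent length bounds are rejected.
#
#   SharedGenerationParams(min_length=10, max_length=20)  # ok
#   SharedGenerationParams(min_length=20, max_length=10)  # ValidationError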
class TextGenerationParamsCheck(SharedGenerationParams):
return_full_text: Optional[bool] = None
num_return_sequences: Optional[NumReturnSequences] = None
class SummarizationParamsCheck(SharedGenerationParams):
num_return_sequences: Optional[NumReturnSequences] = None
class ConversationalInputsCheck(BaseModel):
text: str
past_user_inputs: List[str]
generated_responses: List[str]
class QuestionInputsCheck(BaseModel):
question: str
context: str
class SentenceSimilarityInputsCheck(BaseModel):
source_sentence: str
sentences: List[str]
class TableQuestionAnsweringInputsCheck(BaseModel):
table: Dict[str, List[str]]
query: str
@field_validator("table")
def all_rows_must_have_same_length(cls, table: Dict[str, List[str]]):
rows = list(table.values())
n = len(rows[0])
if all(len(x) == n for x in rows):
return table
raise ValueError("All rows in the table must be the same length")
class TabularDataInputsCheck(BaseModel):
data: Dict[str, List[str]]
@field_validator("data")
def all_rows_must_have_same_length(cls, data: Dict[str, List[str]]):
rows = list(data.values())
n = len(rows[0])
if all(len(x) == n for x in rows):
return data
raise ValueError("All rows in the data must be the same length")
class StringOrStringBatchInputCheck(RootModel):
root: Union[List[str], str]
@field_validator("root")
def input_must_not_be_empty(cls, root: Union[List[str], str]):
if isinstance(root, list):
if len(root) == 0:
raise ValueError(
"The inputs are invalid, at least one input is required"
)
return root
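# Illustrative sketch (assumed): RootModel validates a bare payload, so plain
# strings and non-empty string batches pass while an empty batch is rejected.
#
#   StringOrStringBatchInputCheck.model_validate("hello")     # ok
#   StringOrStringBatchInputCheck.model_validate(["a", "b"])  # ok
#   StringOrStringBatchInputCheck.model_validate([])          # ValidationError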
class StringInput(RootModel):
root: str
PARAMS_MAPPING = {
"conversational": SharedGenerationParams,
"fill-mask": FillMaskParamsCheck,
"text2text-generation": TextGenerationParamsCheck,
"text-generation": TextGenerationParamsCheck,
"summarization": SummarizationParamsCheck,
"zero-shot-classification": ZeroShotParamsCheck,
}
INPUTS_MAPPING = {
"conversational": ConversationalInputsCheck,
"question-answering": QuestionInputsCheck,
"feature-extraction": StringOrStringBatchInputCheck,
"sentence-similarity": SentenceSimilarityInputsCheck,
"table-question-answering": TableQuestionAnsweringInputsCheck,
"tabular-classification": TabularDataInputsCheck,
"tabular-regression": TabularDataInputsCheck,
"fill-mask": StringInput,
"summarization": StringInput,
"text2text-generation": StringInput,
"text-generation": StringInput,
"text-classification": StringInput,
"token-classification": StringInput,
"translation": StringInput,
"zero-shot-classification": StringInput,
"text-to-speech": StringInput,
"text-to-image": StringInput,
}
BATCH_ENABLED_PIPELINES = ["feature-extraction"]
def check_params(params, tag):
    # Tasks without an entry in PARAMS_MAPPING accept arbitrary parameters.
    if tag in PARAMS_MAPPING:
        PARAMS_MAPPING[tag].model_validate(params)
    return True
def check_inputs(inputs, tag):
if tag in INPUTS_MAPPING:
INPUTS_MAPPING[tag].model_validate(inputs)
return True
else:
raise ValueError(f"{tag} is not a valid pipeline.")
AUDIO_INPUTS = {
"automatic-speech-recognition",
"audio-to-audio",
"speech-segmentation",
"audio-classification",
}
AUDIO_OUTPUTS = {
"audio-to-audio",
"text-to-speech",
}
IMAGE_INPUTS = {
"image-classification",
"image-segmentation",
"image-to-text",
"image-to-image",
"object-detection",
"zero-shot-image-classification",
}
IMAGE_OUTPUTS = {
"image-to-image",
"text-to-image",
"latent-to-image",
}
TENSOR_INPUTS = {
"latent-to-image",
}
TEXT_INPUTS = {
"conversational",
"feature-extraction",
"question-answering",
"sentence-similarity",
"fill-mask",
"table-question-answering",
"tabular-classification",
"tabular-regression",
"summarization",
"text-generation",
"text2text-generation",
"text-classification",
"text-to-image",
"text-to-speech",
"token-classification",
"zero-shot-classification",
}
KNOWN_TASKS = AUDIO_INPUTS.union(IMAGE_INPUTS).union(TEXT_INPUTS).union(TENSOR_INPUTS)
AUDIO = [
"flac",
"ogg",
"mp3",
"wav",
"m4a",
"aac",
"webm",
]
IMAGE = [
"jpeg",
"png",
"webp",
"tiff",
"bmp",
]
def parse_accept(accept: str, accepted: List[str]) -> str:
    """Pick the first extension from the Accept header that we can produce."""
for mimetype in accept.split(","):
# remove quality
mimetype = mimetype.split(";")[0]
# remove prefix
extension = mimetype.split("/")[-1]
if extension in accepted:
return extension
return accepted[0]
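# Worked examples (illustrative): quality factors and the type prefix are
# stripped before matching, with the first accepted entry as fallback.
#
#   parse_accept("audio/wav;q=0.9,audio/ogg", AUDIO)  # -> "wav"
#   parse_accept("text/html", IMAGE)                  # -> "jpeg" (fallback)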
def normalize_payload(
bpayload: bytes, task: str, sampling_rate: Optional[int]
) -> Tuple[Any, Dict]:
if task in AUDIO_INPUTS:
        if sampling_rate is None:
            raise EnvironmentError(
                "We cannot normalize an audio file if we don't know the sampling rate"
            )
return normalize_payload_audio(bpayload, sampling_rate)
elif task in IMAGE_INPUTS:
return normalize_payload_image(bpayload)
elif task in TEXT_INPUTS:
return normalize_payload_nlp(bpayload, task)
elif task in TENSOR_INPUTS:
return normalize_payload_tensor(bpayload)
else:
raise EnvironmentError(
f"The task `{task}` is not recognized by api-inference-community"
)
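# Dispatch sketch (illustrative; `flac_bytes` is a hypothetical payload):
#
#   normalize_payload(b'"some text"', "text-classification", sampling_rate=None)
#   normalize_payload(flac_bytes, "automatic-speech-recognition", sampling_rate=16000)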
def ffmpeg_convert(
    array: np.ndarray, sampling_rate: int, format_for_conversion: str
) -> bytes:
    """
    Helper function to encode a raw mono float32 waveform into the requested
    audio format (the compression is lossless only for formats such as flac).
    """
    ar = str(sampling_rate)
    ac = "1"
    ffmpeg_command = [
        "ffmpeg",
        # Input: raw single-channel float32 PCM read from stdin
        "-f",
        "f32le",
        "-ac",
        ac,
        "-ar",
        ar,
        "-i",
        "pipe:0",
        # Output: the requested format written to stdout
        "-f",
        format_for_conversion,
        "-hide_banner",
        "-loglevel",
        "quiet",
        "pipe:1",
    ]
    ffmpeg_process = subprocess.Popen(
        ffmpeg_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE
    )
    output_stream = ffmpeg_process.communicate(array.tobytes())
    out_bytes = output_stream[0]
    if len(out_bytes) == 0:
        raise Exception("ffmpeg produced no output, the conversion failed")
    return out_bytes
def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.ndarray:
    """
    Decode an audio payload into a mono float32 waveform by piping it through
    ffmpeg. Librosa does the same under the hood but forces the use of an
    actual file, hitting disk, which is almost always best avoided.
    """
    ar = f"{sampling_rate}"
    ac = "1"
    format_for_conversion = "f32le"
ffmpeg_command = [
"ffmpeg",
"-i",
"pipe:0",
"-ac",
ac,
"-ar",
ar,
"-f",
format_for_conversion,
"-hide_banner",
"-loglevel",
"quiet",
"pipe:1",
]
ffmpeg_process = subprocess.Popen(
ffmpeg_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE
)
output_stream = ffmpeg_process.communicate(bpayload)
out_bytes = output_stream[0]
audio = np.frombuffer(out_bytes, np.float32).copy()
if audio.shape[0] == 0:
raise ValueError("Malformed soundfile")
return audio
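# Hedged round-trip sketch (assumes an `ffmpeg` binary on PATH):
#
#   wav_bytes = ffmpeg_convert(np.zeros(16000, dtype=np.float32), 16000, "wav")
#   waveform = ffmpeg_read(wav_bytes, 16000)  # ~16000 float32 samples back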
def normalize_payload_image(bpayload: bytes) -> Tuple[Any, Dict]:
from PIL import Image
try:
# We accept both binary image with mimetype
# and {"inputs": base64encodedimage}
data = json.loads(bpayload)
image = data["image"] if "image" in data else data["inputs"]
image_bytes = b64decode(image)
img = Image.open(BytesIO(image_bytes))
return img, data.get("parameters", {})
    except Exception:
        # Not a JSON/base64 payload: fall through and treat the body as raw image bytes.
        pass
    img = Image.open(BytesIO(bpayload))
    return img, {}
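# Illustrative payloads (assumed; `raw_jpeg_bytes` is hypothetical and the JSON
# variant needs a client-side `import base64`):
#
#   normalize_payload_image(raw_jpeg_bytes)  # raw binary body
#   normalize_payload_image(
#       json.dumps({"inputs": base64.b64encode(raw_jpeg_bytes).decode()}).encode()
#   )                                        # JSON + base64 body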
DATA_PREFIX = os.getenv("HF_TRANSFORMERS_CACHE", "")
def normalize_payload_audio(bpayload: bytes, sampling_rate: int) -> Tuple[Any, Dict]:
if os.path.isfile(bpayload) and bpayload.startswith(DATA_PREFIX.encode("utf-8")):
# XXX:
# This is necessary for batch jobs where the datasets can contain
# filenames instead of the raw data.
# We attempt to sanitize this roughly, by checking it lives on the data
# path (hardcoded in the deployment and in all the dockerfiles)
# We also attempt to prevent opening files that are not obviously
# audio files, to prevent opening stuff like model weights.
        _, ext = os.path.splitext(bpayload)
        if ext.decode("utf-8")[1:] in AUDIO:
            with open(bpayload, "rb") as f:
                bpayload = f.read()
    inputs = ffmpeg_read(bpayload, sampling_rate)
    if len(inputs.shape) > 1:
        # ogg can contain dual-channel audio -> keep only the first channel
        inputs = inputs[:, 0]
return inputs, {}
def normalize_payload_nlp(bpayload: bytes, task: str) -> Tuple[Any, Dict]:
payload = bpayload.decode("utf-8")
    # We used to accept raw strings; keep accepting them for backward compatibility
try:
payload = json.loads(payload)
if isinstance(payload, (float, int)):
payload = str(payload)
except Exception:
pass
parameters: Dict[str, Any] = {}
if isinstance(payload, dict) and "inputs" in payload:
inputs = payload["inputs"]
parameters = payload.get("parameters", {})
else:
inputs = payload
check_params(parameters, task)
check_inputs(inputs, task)
return inputs, parameters
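# Illustrative payloads (assumed): legacy raw strings and the JSON envelope
# {"inputs": ..., "parameters": ...} are both accepted.
#
#   normalize_payload_nlp(b'"just a string"', "text-classification")
#   # -> ("just a string", {})
#   normalize_payload_nlp(b'{"inputs": "Hello", "parameters": {"top_k": 3}}', "fill-mask")
#   # -> ("Hello", {"top_k": 3})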
def normalize_payload_tensor(bpayload: bytes) -> Tuple[Any, Dict]:
import torch
data = json.loads(bpayload)
tensor = data["inputs"]
tensor = b64decode(tensor.encode("utf-8"))
parameters = data.get("parameters", {})
if "shape" not in parameters:
raise ValueError("Expected `shape` in parameters.")
if "dtype" not in parameters:
raise ValueError("Expected `dtype` in parameters.")
DTYPE_MAP = {
"float16": torch.float16,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
    shape = parameters.pop("shape")
    dtype_str = parameters.pop("dtype")
    if dtype_str not in DTYPE_MAP:
        # Fail early with a clear message instead of passing dtype=None to torch.
        raise ValueError(f"Unsupported `dtype` {dtype_str}.")
    tensor = torch.frombuffer(bytearray(tensor), dtype=DTYPE_MAP[dtype_str]).reshape(shape)
    return tensor, parameters
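# Hedged client-side encoding sketch (assumed; needs `torch` and `base64`):
#
#   t = torch.zeros(2, 3, dtype=torch.float16)
#   body = json.dumps({
#       "inputs": base64.b64encode(t.numpy().tobytes()).decode(),
#       "parameters": {"shape": [2, 3], "dtype": "float16"},
#   }).encode()
#   tensor, params = normalize_payload_tensor(body)  # params == {}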