in src/pipelines.js [2685:3497]
box: get_bounding_box(box, !percentage),
}))
}
result.sort((a, b) => b.score - a.score);
if (top_k !== null) {
result = result.slice(0, top_k);
}
toReturn.push(result)
}
return isBatched ? toReturn : toReturn[0];
}
}
/**
* @typedef {Object} DocumentQuestionAnsweringSingle
* @property {string} answer The answer to the question, extracted from the generated output.
* @typedef {DocumentQuestionAnsweringSingle[]} DocumentQuestionAnsweringOutput
*
* @callback DocumentQuestionAnsweringPipelineCallback Answer the question given as input by using the document.
* @param {ImageInput} image The image of the document to use.
* @param {string} question A question to ask of the document.
* @param {Partial<import('./generation/configuration_utils.js').GenerationConfig>} [options] Additional keyword arguments to pass along to the generate method of the model.
* @returns {Promise<DocumentQuestionAnsweringOutput|DocumentQuestionAnsweringOutput[]>} An object (or array of objects) containing the answer(s).
*
* @typedef {TextImagePipelineConstructorArgs & DocumentQuestionAnsweringPipelineCallback & Disposable} DocumentQuestionAnsweringPipelineType
*/
/**
* Document Question Answering pipeline using any `AutoModelForDocumentQuestionAnswering`.
* The inputs/outputs are similar to the (extractive) question answering pipeline; however,
* the pipeline takes an image (and optional OCR'd words/boxes) as input instead of text context.
*
* **Example:** Answer questions about a document with `Xenova/donut-base-finetuned-docvqa`.
* ```javascript
* const qa_pipeline = await pipeline('document-question-answering', 'Xenova/donut-base-finetuned-docvqa');
* const image = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/invoice.png';
* const question = 'What is the invoice number?';
* const output = await qa_pipeline(image, question);
* // [{ answer: 'us-001' }]
* ```
*/
export class DocumentQuestionAnsweringPipeline extends (/** @type {new (options: TextImagePipelineConstructorArgs) => DocumentQuestionAnsweringPipelineType} */ (Pipeline)) {
/**
* Create a new DocumentQuestionAnsweringPipeline.
* @param {TextImagePipelineConstructorArgs} options An object used to instantiate the pipeline.
*/
constructor(options) {
super(options);
}
/** @type {DocumentQuestionAnsweringPipelineCallback} */
async _call(image, question, generate_kwargs = {}) {
// NOTE: For now, we only support a batch size of 1
// Preprocess image
const preparedImage = (await prepareImages(image))[0];
const { pixel_values } = await this.processor(preparedImage);
// Run tokenization
const task_prompt = `<s_docvqa><s_question>${question}</s_question><s_answer>`;
const decoder_input_ids = this.tokenizer(task_prompt, {
add_special_tokens: false,
padding: true,
truncation: true,
}).input_ids;
// Run model
const output = await this.model.generate({
inputs: pixel_values,
// @ts-expect-error TS2339
max_length: this.model.config.decoder.max_position_embeddings,
decoder_input_ids,
...generate_kwargs,
});
// Decode output
const decoded = this.tokenizer.batch_decode(/** @type {Tensor} */(output))[0];
// Parse answer
const match = decoded.match(/<s_answer>(.*?)<\/s_answer>/);
let answer = null;
if (match && match.length >= 2) {
answer = match[1].trim();
}
return [{ answer }];
}
}
/**
* @typedef {Object} VocoderOptions
* @property {PreTrainedModel} [vocoder] The vocoder used by the pipeline (if the model uses one). If not provided, the default HifiGan vocoder is used.
* @typedef {TextAudioPipelineConstructorArgs & VocoderOptions} TextToAudioPipelineConstructorArgs
*/
/**
* @typedef {Object} TextToAudioOutput
* @property {Float32Array} audio The generated audio waveform.
* @property {number} sampling_rate The sampling rate of the generated audio waveform.
*
* @typedef {Object} TextToAudioPipelineOptions Parameters specific to text-to-audio pipelines.
* @property {Tensor|Float32Array|string|URL} [speaker_embeddings=null] The speaker embeddings (if the model requires it).
*
* @callback TextToAudioPipelineCallback Generates speech/audio from the inputs.
* @param {string|string[]} texts The text(s) to generate audio from.
* @param {TextToAudioPipelineOptions} [options] Parameters passed to the model generation/forward method.
* @returns {Promise<TextToAudioOutput>} An object containing the generated audio and sampling rate.
*
* @typedef {TextToAudioPipelineConstructorArgs & TextToAudioPipelineCallback & Disposable} TextToAudioPipelineType
*/
/**
* Text-to-audio generation pipeline using any `AutoModelForTextToWaveform` or `AutoModelForTextToSpectrogram`.
* This pipeline generates an audio file from an input text and optional other conditional inputs.
*
* **Example:** Generate audio from text with `Xenova/speecht5_tts`.
* ```javascript
* const synthesizer = await pipeline('text-to-speech', 'Xenova/speecht5_tts', { quantized: false });
* const speaker_embeddings = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/speaker_embeddings.bin';
* const out = await synthesizer('Hello, my dog is cute', { speaker_embeddings });
* // RawAudio {
* // audio: Float32Array(26112) [-0.00005657337896991521, 0.00020583874720614403, ...],
* // sampling_rate: 16000
* // }
* ```
*
* You can then save the audio to a .wav file with the `wavefile` package:
* ```javascript
* import wavefile from 'wavefile';
* import fs from 'fs';
*
* const wav = new wavefile.WaveFile();
* wav.fromScratch(1, out.sampling_rate, '32f', out.audio);
* fs.writeFileSync('out.wav', wav.toBuffer());
* ```
*
* **Example:** Multilingual speech generation with `Xenova/mms-tts-fra`. See [here](https://huggingface.co/models?pipeline_tag=text-to-speech&other=vits&sort=trending) for the full list of available languages (1107).
* ```javascript
* const synthesizer = await pipeline('text-to-speech', 'Xenova/mms-tts-fra');
* const out = await synthesizer('Bonjour');
* // RawAudio {
* // audio: Float32Array(23808) [-0.00037693005288019776, 0.0003325853613205254, ...],
* // sampling_rate: 16000
* // }
* ```
*/
export class TextToAudioPipeline extends (/** @type {new (options: TextToAudioPipelineConstructorArgs) => TextToAudioPipelineType} */ (Pipeline)) {
DEFAULT_VOCODER_ID = "Xenova/speecht5_hifigan"
/**
* Create a new TextToAudioPipeline.
* @param {TextToAudioPipelineConstructorArgs} options An object used to instantiate the pipeline.
*/
constructor(options) {
super(options);
// TODO: Find a better way for `pipeline` to set the default vocoder
this.vocoder = options.vocoder ?? null;
}
/** @type {TextToAudioPipelineCallback} */
async _call(text_inputs, {
speaker_embeddings = null,
} = {}) {
// If this.processor is not set, we are using an `AutoModelForTextToWaveform` model
if (this.processor) {
return this._call_text_to_spectrogram(text_inputs, { speaker_embeddings });
} else {
return this._call_text_to_waveform(text_inputs);
}
}
async _call_text_to_waveform(text_inputs) {
// Run tokenization
const inputs = this.tokenizer(text_inputs, {
padding: true,
truncation: true,
});
// Generate waveform
const { waveform } = await this.model(inputs);
// @ts-expect-error TS2339
const sampling_rate = this.model.config.sampling_rate;
return new RawAudio(
waveform.data,
sampling_rate,
)
}
async _call_text_to_spectrogram(text_inputs, { speaker_embeddings }) {
// Load vocoder, if not provided
if (!this.vocoder) {
console.log('No vocoder specified, using default HifiGan vocoder.');
this.vocoder = await AutoModel.from_pretrained(this.DEFAULT_VOCODER_ID, { dtype: 'fp32' });
}
// Load speaker embeddings as Float32Array from path/URL
if (typeof speaker_embeddings === 'string' || speaker_embeddings instanceof URL) {
// Load from URL with fetch
speaker_embeddings = new Float32Array(
await (await fetch(speaker_embeddings)).arrayBuffer()
);
}
if (speaker_embeddings instanceof Float32Array) {
speaker_embeddings = new Tensor(
'float32',
speaker_embeddings,
[1, speaker_embeddings.length]
)
} else if (!(speaker_embeddings instanceof Tensor)) {
throw new Error("Speaker embeddings must be a `Tensor`, `Float32Array`, `string`, or `URL`.")
}
// Run tokenization
const { input_ids } = this.tokenizer(text_inputs, {
padding: true,
truncation: true,
});
// NOTE: At this point, we are guaranteed that `speaker_embeddings` is a `Tensor`
// @ts-ignore
const { waveform } = await this.model.generate_speech(input_ids, speaker_embeddings, { vocoder: this.vocoder });
const sampling_rate = this.processor.feature_extractor.config.sampling_rate;
return new RawAudio(
waveform.data,
sampling_rate,
)
}
}
/**
* @callback ImageToImagePipelineCallback Transform the image(s) passed as inputs.
* @param {ImagePipelineInputs} images The images to transform.
* @returns {Promise<RawImage|RawImage[]>} The transformed image or list of images.
*
* @typedef {ImagePipelineConstructorArgs & ImageToImagePipelineCallback & Disposable} ImageToImagePipelineType
*/
/**
* Image to Image pipeline using any `AutoModelForImageToImage`. This pipeline generates an image based on a previous image input.
*
* **Example:** Super-resolution w/ `Xenova/swin2SR-classical-sr-x2-64`
* ```javascript
* const upscaler = await pipeline('image-to-image', 'Xenova/swin2SR-classical-sr-x2-64');
* const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/butterfly.jpg';
* const output = await upscaler(url);
* // RawImage {
* // data: Uint8Array(786432) [ 41, 31, 24, 43, ... ],
* // width: 512,
* // height: 512,
* // channels: 3
* // }
* ```
*/
export class ImageToImagePipeline extends (/** @type {new (options: ImagePipelineConstructorArgs) => ImageToImagePipelineType} */ (Pipeline)) {
/**
* Create a new ImageToImagePipeline.
* @param {ImagePipelineConstructorArgs} options An object used to instantiate the pipeline.
*/
constructor(options) {
super(options);
}
/** @type {ImageToImagePipelineCallback} */
async _call(images) {
const preparedImages = await prepareImages(images);
const inputs = await this.processor(preparedImages);
const outputs = await this.model(inputs);
/** @type {RawImage[]} */
const toReturn = [];
for (const batch of outputs.reconstruction) {
const output = batch.squeeze().clamp_(0, 1).mul_(255).round_().to('uint8');
toReturn.push(RawImage.fromTensor(output));
}
return toReturn.length > 1 ? toReturn : toReturn[0];
}
}
/**
* @typedef {Object} DepthEstimationPipelineOutput
* @property {Tensor} predicted_depth The raw depth map predicted by the model.
* @property {RawImage} depth The processed depth map as an image (with the same size as the input image).
*
* @callback DepthEstimationPipelineCallback Predicts the depth for the image(s) passed as inputs.
* @param {ImagePipelineInputs} images The images to compute depth for.
* @returns {Promise<DepthEstimationPipelineOutput|DepthEstimationPipelineOutput[]>} The depth estimation result (or list of results), each containing the raw predicted depth tensor and a depth map image.
*
* @typedef {ImagePipelineConstructorArgs & DepthEstimationPipelineCallback & Disposable} DepthEstimationPipelineType
*/
/**
* Depth estimation pipeline using any `AutoModelForDepthEstimation`. This pipeline predicts the depth of an image.
*
* **Example:** Depth estimation w/ `Xenova/dpt-hybrid-midas`
* ```javascript
* const depth_estimator = await pipeline('depth-estimation', 'Xenova/dpt-hybrid-midas');
* const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg';
* const out = await depth_estimator(url);
* // {
* // predicted_depth: Tensor {
* // dims: [ 384, 384 ],
* // type: 'float32',
* // data: Float32Array(147456) [ 542.859130859375, 545.2833862304688, 546.1649169921875, ... ],
* // size: 147456
* // },
* // depth: RawImage {
* // data: Uint8Array(307200) [ 86, 86, 86, ... ],
* // width: 640,
* // height: 480,
* // channels: 1
* // }
* // }
* ```
*/
export class DepthEstimationPipeline extends (/** @type {new (options: ImagePipelineConstructorArgs) => DepthEstimationPipelineType} */ (Pipeline)) {
/**
* Create a new DepthEstimationPipeline.
* @param {ImagePipelineConstructorArgs} options An object used to instantiate the pipeline.
*/
constructor(options) {
super(options);
}
/** @type {DepthEstimationPipelineCallback} */
async _call(images) {
const preparedImages = await prepareImages(images);
const inputs = await this.processor(preparedImages);
const { predicted_depth } = await this.model(inputs);
const toReturn = [];
for (let i = 0; i < preparedImages.length; ++i) {
const batch = predicted_depth[i];
const [height, width] = batch.dims.slice(-2);
const [new_width, new_height] = preparedImages[i].size;
// Interpolate to original size
const prediction = (await interpolate_4d(batch.view(1, 1, height, width), {
size: [new_height, new_width],
mode: 'bilinear',
})).view(new_height, new_width);
const minval = /** @type {number} */(prediction.min().item());
const maxval = /** @type {number} */(prediction.max().item());
const formatted = prediction.sub(minval).div_(maxval - minval).mul_(255).to('uint8').unsqueeze(0);
const depth = RawImage.fromTensor(formatted);
toReturn.push({
predicted_depth: prediction,
depth,
});
}
return toReturn.length > 1 ? toReturn : toReturn[0];
}
}
const SUPPORTED_TASKS = Object.freeze({
"text-classification": {
"tokenizer": AutoTokenizer,
"pipeline": TextClassificationPipeline,
"model": AutoModelForSequenceClassification,
"default": {
// TODO: replace with original
// "model": "distilbert-base-uncased-finetuned-sst-2-english",
"model": "Xenova/distilbert-base-uncased-finetuned-sst-2-english",
},
"type": "text",
},
"token-classification": {
"tokenizer": AutoTokenizer,
"pipeline": TokenClassificationPipeline,
"model": AutoModelForTokenClassification,
"default": {
// TODO: replace with original
// "model": "Davlan/bert-base-multilingual-cased-ner-hrl",
"model": "Xenova/bert-base-multilingual-cased-ner-hrl",
},
"type": "text",
},
"question-answering": {
"tokenizer": AutoTokenizer,
"pipeline": QuestionAnsweringPipeline,
"model": AutoModelForQuestionAnswering,
"default": {
// TODO: replace with original
// "model": "distilbert-base-cased-distilled-squad",
"model": "Xenova/distilbert-base-cased-distilled-squad",
},
"type": "text",
},
"fill-mask": {
"tokenizer": AutoTokenizer,
"pipeline": FillMaskPipeline,
"model": AutoModelForMaskedLM,
"default": {
// TODO: replace with original
// "model": "bert-base-uncased",
"model": "Xenova/bert-base-uncased",
},
"type": "text",
},
"summarization": {
"tokenizer": AutoTokenizer,
"pipeline": SummarizationPipeline,
"model": AutoModelForSeq2SeqLM,
"default": {
// TODO: replace with original
// "model": "sshleifer/distilbart-cnn-6-6",
"model": "Xenova/distilbart-cnn-6-6",
},
"type": "text",
},
"translation": {
"tokenizer": AutoTokenizer,
"pipeline": TranslationPipeline,
"model": AutoModelForSeq2SeqLM,
"default": {
// TODO: replace with original
// "model": "t5-small",
"model": "Xenova/t5-small",
},
"type": "text",
},
"text2text-generation": {
"tokenizer": AutoTokenizer,
"pipeline": Text2TextGenerationPipeline,
"model": AutoModelForSeq2SeqLM,
"default": {
// TODO: replace with original
// "model": "google/flan-t5-small",
"model": "Xenova/flan-t5-small",
},
"type": "text",
},
"text-generation": {
"tokenizer": AutoTokenizer,
"pipeline": TextGenerationPipeline,
"model": AutoModelForCausalLM,
"default": {
// TODO: replace with original
// "model": "gpt2",
"model": "Xenova/gpt2",
},
"type": "text",
},
"zero-shot-classification": {
"tokenizer": AutoTokenizer,
"pipeline": ZeroShotClassificationPipeline,
"model": AutoModelForSequenceClassification,
"default": {
// TODO: replace with original
// "model": "typeform/distilbert-base-uncased-mnli",
"model": "Xenova/distilbert-base-uncased-mnli",
},
"type": "text",
},
"audio-classification": {
"pipeline": AudioClassificationPipeline,
"model": AutoModelForAudioClassification,
"processor": AutoProcessor,
"default": {
// TODO: replace with original
// "model": "superb/wav2vec2-base-superb-ks",
"model": "Xenova/wav2vec2-base-superb-ks",
},
"type": "audio",
},
"zero-shot-audio-classification": {
"tokenizer": AutoTokenizer,
"pipeline": ZeroShotAudioClassificationPipeline,
"model": AutoModel,
"processor": AutoProcessor,
"default": {
// TODO: replace with original
// "model": "laion/clap-htsat-fused",
"model": "Xenova/clap-htsat-unfused",
},
"type": "multimodal",
},
"automatic-speech-recognition": {
"tokenizer": AutoTokenizer,
"pipeline": AutomaticSpeechRecognitionPipeline,
"model": [AutoModelForSpeechSeq2Seq, AutoModelForCTC],
"processor": AutoProcessor,
"default": {
// TODO: replace with original
// "model": "openai/whisper-tiny.en",
"model": "Xenova/whisper-tiny.en",
},
"type": "multimodal",
},
"text-to-audio": {
"tokenizer": AutoTokenizer,
"pipeline": TextToAudioPipeline,
"model": [AutoModelForTextToWaveform, AutoModelForTextToSpectrogram],
"processor": [AutoProcessor, /* Some don't use a processor */ null],
"default": {
// TODO: replace with original
// "model": "microsoft/speecht5_tts",
"model": "Xenova/speecht5_tts",
},
"type": "text",
},
"image-to-text": {
"tokenizer": AutoTokenizer,
"pipeline": ImageToTextPipeline,
"model": AutoModelForVision2Seq,
"processor": AutoProcessor,
"default": {
// TODO: replace with original
// "model": "nlpconnect/vit-gpt2-image-captioning",
"model": "Xenova/vit-gpt2-image-captioning",
},
"type": "multimodal",
},
"image-classification": {
// no tokenizer
"pipeline": ImageClassificationPipeline,
"model": AutoModelForImageClassification,
"processor": AutoProcessor,
"default": {
// TODO: replace with original
// "model": "google/vit-base-patch16-224",
"model": "Xenova/vit-base-patch16-224",
},
"type": "multimodal",
},
"image-segmentation": {
// no tokenizer
"pipeline": ImageSegmentationPipeline,
"model": [AutoModelForImageSegmentation, AutoModelForSemanticSegmentation, AutoModelForUniversalSegmentation],
"processor": AutoProcessor,
"default": {
// TODO: replace with original
// "model": "facebook/detr-resnet-50-panoptic",
"model": "Xenova/detr-resnet-50-panoptic",
},
"type": "multimodal",
},
"background-removal": {
// no tokenizer
"pipeline": BackgroundRemovalPipeline,
"model": [AutoModelForImageSegmentation, AutoModelForSemanticSegmentation, AutoModelForUniversalSegmentation],
"processor": AutoProcessor,
"default": {
"model": "Xenova/modnet",
},
"type": "image",
},
"zero-shot-image-classification": {
"tokenizer": AutoTokenizer,
"pipeline": ZeroShotImageClassificationPipeline,
"model": AutoModel,
"processor": AutoProcessor,
"default": {
// TODO: replace with original
// "model": "openai/clip-vit-base-patch32",
"model": "Xenova/clip-vit-base-patch32",
},
"type": "multimodal",
},
"object-detection": {
// no tokenizer
"pipeline": ObjectDetectionPipeline,
"model": AutoModelForObjectDetection,
"processor": AutoProcessor,
"default": {
// TODO: replace with original
// "model": "facebook/detr-resnet-50",
"model": "Xenova/detr-resnet-50",
},
"type": "multimodal",
},
"zero-shot-object-detection": {
"tokenizer": AutoTokenizer,
"pipeline": ZeroShotObjectDetectionPipeline,
"model": AutoModelForZeroShotObjectDetection,
"processor": AutoProcessor,
"default": {
// TODO: replace with original
// "model": "google/owlvit-base-patch32",
"model": "Xenova/owlvit-base-patch32",
},
"type": "multimodal",
},
"document-question-answering": {
"tokenizer": AutoTokenizer,
"pipeline": DocumentQuestionAnsweringPipeline,
"model": AutoModelForDocumentQuestionAnswering,
"processor": AutoProcessor,
"default": {
// TODO: replace with original
// "model": "naver-clova-ix/donut-base-finetuned-docvqa",
"model": "Xenova/donut-base-finetuned-docvqa",
},
"type": "multimodal",
},
"image-to-image": {
// no tokenizer
"pipeline": ImageToImagePipeline,
"model": AutoModelForImageToImage,
"processor": AutoProcessor,
"default": {
// TODO: replace with original
// "model": "caidas/swin2SR-classical-sr-x2-64",
"model": "Xenova/swin2SR-classical-sr-x2-64",
},
"type": "image",
},
"depth-estimation": {
// no tokenizer
"pipeline": DepthEstimationPipeline,
"model": AutoModelForDepthEstimation,
"processor": AutoProcessor,
"default": {
// TODO: replace with original
// "model": "Intel/dpt-large",
"model": "Xenova/dpt-large",
},
"type": "image",
},
// This task serves as a useful interface for dealing with sentence-transformers (https://huggingface.co/sentence-transformers).
"feature-extraction": {
"tokenizer": AutoTokenizer,
"pipeline": FeatureExtractionPipeline,
"model": AutoModel,
"default": {
// TODO: replace with original
// "model": "sentence-transformers/all-MiniLM-L6-v2",
"model": "Xenova/all-MiniLM-L6-v2",
},
"type": "text",
},
"image-feature-extraction": {
"processor": AutoProcessor,
"pipeline": ImageFeatureExtractionPipeline,
"model": [AutoModelForImageFeatureExtraction, AutoModel],
"default": {
// TODO: replace with original
// "model": "google/vit-base-patch16-224",
"model": "Xenova/vit-base-patch16-224-in21k",
},
"type": "image",
},
})
// TODO: Add types for TASK_ALIASES
const TASK_ALIASES = Object.freeze({
"sentiment-analysis": "text-classification",
"ner": "token-classification",
// "vqa": "visual-question-answering", // TODO: Add
"asr": "automatic-speech-recognition",
"text-to-speech": "text-to-audio",
// Add for backwards compatibility
"embeddings": "feature-extraction",
});
/**
* @typedef {keyof typeof SUPPORTED_TASKS} TaskType
* @typedef {keyof typeof TASK_ALIASES} AliasType
* @typedef {TaskType | AliasType} PipelineType All possible pipeline types.
* @typedef {{[K in TaskType]: InstanceType<typeof SUPPORTED_TASKS[K]["pipeline"]>}} SupportedTasks A mapping of pipeline names to their corresponding pipeline classes.
* @typedef {{[K in AliasType]: InstanceType<typeof SUPPORTED_TASKS[TASK_ALIASES[K]]["pipeline"]>}} AliasTasks A mapping from pipeline aliases to their corresponding pipeline classes.
* @typedef {SupportedTasks & AliasTasks} AllTasks A mapping from all pipeline names and aliases to their corresponding pipeline classes.
*/
/**
* Utility factory method to build a `Pipeline` object.
*
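* **Example:** A minimal usage sketch. The input text and the output shown are illustrative; when no model id is passed, the task's default model (see `SUPPORTED_TASKS` below) is loaded.
* ```javascript
* // Allocate a pipeline for sentiment analysis (an alias of 'text-classification')
* const classifier = await pipeline('sentiment-analysis');
*
* // Run the pipeline on some input text
* const output = await classifier('I love transformers!');
* // [{ label: 'POSITIVE', score: 0.99... }] (approximate)
* ```
*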
* @template {PipelineType} T The type of pipeline to return.
* @param {T} task The task defining which pipeline will be returned. Currently accepted tasks are:
* - `"audio-classification"`: will return a `AudioClassificationPipeline`.
* - `"automatic-speech-recognition"`: will return a `AutomaticSpeechRecognitionPipeline`.
* - `"depth-estimation"`: will return a `DepthEstimationPipeline`.
* - `"document-question-answering"`: will return a `DocumentQuestionAnsweringPipeline`.
* - `"feature-extraction"`: will return a `FeatureExtractionPipeline`.
* - `"fill-mask"`: will return a `FillMaskPipeline`.
* - `"image-classification"`: will return a `ImageClassificationPipeline`.
* - `"image-segmentation"`: will return a `ImageSegmentationPipeline`.
* - `"image-to-text"`: will return a `ImageToTextPipeline`.
* - `"object-detection"`: will return a `ObjectDetectionPipeline`.
* - `"question-answering"`: will return a `QuestionAnsweringPipeline`.
* - `"summarization"`: will return a `SummarizationPipeline`.
* - `"text2text-generation"`: will return a `Text2TextGenerationPipeline`.
* - `"text-classification"` (alias "sentiment-analysis" available): will return a `TextClassificationPipeline`.
* - `"text-generation"`: will return a `TextGenerationPipeline`.
* - `"token-classification"` (alias "ner" available): will return a `TokenClassificationPipeline`.
* - `"translation"`: will return a `TranslationPipeline`.
* - `"translation_xx_to_yy"`: will return a `TranslationPipeline`.
* - `"zero-shot-classification"`: will return a `ZeroShotClassificationPipeline`.
* - `"zero-shot-audio-classification"`: will return a `ZeroShotAudioClassificationPipeline`.
* - `"zero-shot-image-classification"`: will return a `ZeroShotImageClassificationPipeline`.
* - `"zero-shot-object-detection"`: will return a `ZeroShotObjectDetectionPipeline`.
* @param {string} [model=null] The name of the pre-trained model to use. If not specified, the default model for the task will be used.
* @param {import('./utils/hub.js').PretrainedModelOptions} [options] Optional parameters for the pipeline.
* @returns {Promise<AllTasks[T]>} A Pipeline object for the specified task.
* @throws {Error} If an unsupported pipeline is requested.
*/
export async function pipeline(
task,
model = null,
{
progress_callback = null,
config = null,
cache_dir = null,
local_files_only = false,
revision = 'main',
device = null,
dtype = null,
subfolder = 'onnx',
use_external_data_format = null,
model_file_name = null,
session_options = {},
} = {}
) {
// Apply task aliases (e.g. "sentiment-analysis" -> "text-classification")
// @ts-ignore
task = TASK_ALIASES[task] ?? task;
// Get pipeline info
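// For compound translation tasks such as 'translation_en_to_fr', only the part before the first
// underscore ('translation') is used to look up the task in SUPPORTED_TASKS.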
const pipelineInfo = SUPPORTED_TASKS[task.split('_', 1)[0]];
if (!pipelineInfo) {
throw Error(`Unsupported pipeline: ${task}. Must be one of [${Object.keys(SUPPORTED_TASKS)}]`)
}
// Use the specified model, otherwise fall back to the default model for the task
if (!model) {
model = pipelineInfo.default.model
console.log(`No model specified. Using default model: "${model}".`);
}
const pretrainedOptions = {
progress_callback,
config,
cache_dir,
local_files_only,
revision,
device,
dtype,
subfolder,
use_external_data_format,
model_file_name,
session_options,
}
const classes = new Map([
['tokenizer', pipelineInfo.tokenizer],
['model', pipelineInfo.model],
['processor', pipelineInfo.processor],
]);
// Load model, tokenizer, and processor (if they exist)
const results = await loadItems(classes, model, pretrainedOptions);
results.task = task;
dispatchCallback(progress_callback, {
'status': 'ready',
'task': task,
'model': model,
});
const pipelineClass = pipelineInfo.pipeline;
return new pipelineClass(results);
}
/**
* Helper function to load the applicable model, tokenizer, and/or processor for a given model name.
* @param {Map<string, any>} mapping The mapping of names to classes, arrays of classes, or null.
* @param {string} model The name of the model to load.
* @param {import('./utils/hub.js').PretrainedOptions} pretrainedOptions The options to pass to the `from_pretrained` method.
* @private
*/
async function loadItems(mapping, model, pretrainedOptions) {
const result = Object.create(null);
/**@type {Promise[]} */
const promises = [];
for (const [name, cls] of mapping.entries()) {
if (!cls) continue;
/**@type {Promise} */
let promise;
if (Array.isArray(cls)) {
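// Multiple candidate classes were provided for this item (e.g. [AutoModelForSpeechSeq2Seq, AutoModelForCTC]
// for automatic-speech-recognition): try each class in turn and use the first one that loads successfully.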
promise = new Promise(async (resolve, reject) => {
let e;
for (const c of cls) {