async _call()

in src/pipelines.js [1164:1227]


    async _call(texts, candidate_labels, {
        hypothesis_template = "This example is {}.",
        multi_label = false,
    } = {}) {

        const isBatched = Array.isArray(texts);
        if (!isBatched) {
            texts = [/** @type {string} */ (texts)];
        }
        if (!Array.isArray(candidate_labels)) {
            candidate_labels = [candidate_labels];
        }

        // Insert labels into hypothesis template
        const hypotheses = candidate_labels.map(
            x => hypothesis_template.replace('{}', x)
        );

        // How to perform the softmax over the logits:
        //  - true:  softmax over the entailment vs. contradiction dim for each label independently
        //  - false: softmax the "entailment" logits over all candidate labels
        const softmaxEach = multi_label || candidate_labels.length === 1;

        /** @type {ZeroShotClassificationOutput[]} */
        const toReturn = [];
        for (const premise of texts) {
            const entails_logits = [];

            for (const hypothesis of hypotheses) {
                const inputs = this.tokenizer(premise, {
                    text_pair: hypothesis,
                    padding: true,
                    truncation: true,
                })
                const outputs = await this.model(inputs)

                if (softmaxEach) {
                    entails_logits.push([
                        outputs.logits.data[this.contradiction_id],
                        outputs.logits.data[this.entailment_id]
                    ])
                } else {
                    entails_logits.push(outputs.logits.data[this.entailment_id])
                }
            }

            /** @type {number[]} */
            const scores = softmaxEach
                ? entails_logits.map(x => softmax(x)[1])
                : softmax(entails_logits);

            // Sort by scores (desc) and return scores with indices
            const scores_sorted = scores
                .map((x, i) => [x, i])
                .sort((a, b) => (b[0] - a[0]));

            toReturn.push({
                sequence: premise,
                labels: scores_sorted.map(x => candidate_labels[x[1]]),
                scores: scores_sorted.map(x => x[0]),
            });
        }
        return isBatched ? toReturn : toReturn[0];
    }