firestore-semantic-search/functions/src/common/palm_embeddings.ts:

/**
 * Copyright 2023 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {helpers, v1} from '@google-cloud/aiplatform';
import config from '../config';

let client: v1.PredictionServiceClient;

const endpoint = `projects/${config.projectId}/locations/${config.location}/publishers/google/models/${config.palmModel}`;

const initializePaLMClient = async () => {
  const t0 = performance.now();

  // The location is hard-coded here, following
  // https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings#generative-ai-get-text-embedding-nodejs
  const clientOptions = {
    apiEndpoint: 'us-central1-aiplatform.googleapis.com',
  };

  client = new v1.PredictionServiceClient(clientOptions);

  const duration = performance.now() - t0;
  console.log(`Initialized client. This took ${duration}ms`);
};

/**
 * Generates embeddings for a single batch of sentences using the PaLM
 * embedding model.
 *
 * @param batch an array of strings to be embedded.
 * @returns an array of embedding vectors, one per input string.
 */
async function embedBatchPaLM(batch: string[]): Promise<number[][]> {
  const instances = batch.map(text =>
    helpers.toValue({content: text})
  ) as protobuf.common.IValue[];

  const parameters = helpers.toValue({});

  const [response] = await client.predict({
    endpoint,
    instances,
    parameters,
  });

  if (!response || !response.predictions || response.predictions.length === 0)
    throw new Error('Error with embedding');

  const predictionValues = response.predictions as protobuf.common.IValue[];

  const predictions = predictionValues.map(helpers.fromValue) as {
    embeddings: {values: number[]};
  }[];

  if (
    predictions.some(
      prediction => !prediction.embeddings || !prediction.embeddings.values
    )
  ) {
    throw new Error('Error with embedding');
  }

  const embeddings = predictions.map(
    prediction => prediction.embeddings.values
  );

  return embeddings;
}

/**
 * Batches and embeds a given string or array of strings using the PaLM
 * embedding model.
 *
 * @param text a string or array of strings to be embedded.
 * @returns an array of embedding vectors, one per input string.
 */
async function getEmbeddingsPaLM(text: string | string[]): Promise<number[][]> {
  if (typeof text === 'string') text = [text];
  if (text.length === 0) return [];

  // Lazily initialize the client on first use.
  if (!client) await initializePaLMClient();

  // Chunk into batches of 5 (the current limit of the PaLM API).
  const batchSize = 5;
  const batches: string[][] = [];
  for (let i = 0; i < text.length; i += batchSize) {
    batches.push(text.slice(i, i + batchSize));
  }

  const t0 = performance.now();

  // Embed all batches concurrently, then flatten back into a single array
  // that preserves the order of the input strings.
  const embeddingBatches = await Promise.all(
    batches.map(async batch => embedBatchPaLM(batch))
  );

  const embeddings = embeddingBatches.flat();

  const duration = performance.now() - t0;
  console.log(`Processed embeddings. This took ${duration}ms`);

  return embeddings;
}

export default getEmbeddingsPaLM;
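
// Usage sketch (illustrative only, not part of the original module): how a
// caller, e.g. a Firestore-triggered Cloud Function, might consume this
// helper. The input strings below are made up for the example; the only
// guarantee shown is the one the function actually provides, namely one
// embedding vector per input string, in input order.
//
//   import getEmbeddingsPaLM from './common/palm_embeddings';
//
//   const texts = ['What is semantic search?', 'Firestore is a NoSQL database.'];
//   const vectors = await getEmbeddingsPaLM(texts);
//   console.log(vectors.length); // 2 — one embedding per input string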