firestore-palm-summarize-text/functions/src/generator.ts (187 lines of code) (raw):

/**
 * Copyright 2019 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {TextServiceClient} from '@google-ai/generativelanguage';
import {helpers, v1, protos} from '@google-cloud/aiplatform';
import * as logs from './logs';
import {GoogleAuth} from 'google-auth-library';
import {GLGenerateTextRequest} from './types';
import config from './config';

/** Construction-time settings accepted by {@link TextGenerator}. */
export interface TextGeneratorOptions {
  model?: string;
  temperature?: number;
  candidateCount?: number;
  topP?: number;
  topK?: number;
  maxOutputTokens?: number;
  instruction?: string;
  generativeSafetySettings?: GLGenerateTextRequest['safetySettings'];
}

/**
 * Per-call overrides accepted by {@link TextGenerator.generate}: every field of
 * the Generative Language request except the prompt and the model name.
 */
export type TextGeneratorRequestOptions = Omit<
  GLGenerateTextRequest,
  'prompt' | 'model'
>;

type VertexPredictResponse =
  protos.google.cloud.aiplatform.v1beta1.IPredictResponse;

/**
 * Sends a text prompt to either the Vertex AI prediction service or the
 * Generative Language text service, selected by `config.provider`, and
 * normalizes the reply into a {@link TextGeneratorResponse}.
 */
export class TextGenerator {
  private generativeClient: TextServiceClient | null = null;
  private vertexClient: v1.PredictionServiceClient | null = null;
  private endpoint: string;
  // `instruction` and `context` are stored on the instance but are not read by
  // any code in this file.
  instruction?: string;
  context?: string;
  model: string = config.model;
  temperature?: number;
  candidateCount?: number;
  topP?: number;
  topK?: number;
  maxOutputTokens: number;
  generativeSafetySettings: TextGeneratorRequestOptions['safetySettings'];

  /**
   * Stores the sampling defaults and creates the client for the configured
   * provider.
   *
   * @param options Instance-level defaults. `maxOutputTokens` falls back to
   *   1024 and `model` falls back to `config.model`.
   */
  constructor(options: TextGeneratorOptions = {}) {
    this.temperature = options.temperature;
    this.topP = options.topP;
    this.topK = options.topK;
    this.maxOutputTokens = options.maxOutputTokens || 1024;
    this.candidateCount = options.candidateCount;
    this.instruction = options.instruction;
    this.generativeSafetySettings = options.generativeSafetySettings || [];

    if (options.model) this.model = options.model;

    this.endpoint = `projects/${config.projectId}/locations/${config.location}/publishers/google/models/${this.model}`;

    if (config.provider === 'vertex') {
      // here location is hard-coded, following https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings#generative-ai-get-text-embedding-nodejs
      // NOTE(review): the resource path above uses config.location while this
      // API endpoint is fixed to us-central1 - confirm the two are expected to
      // agree for every supported deployment.
      const clientOptions = {
        apiEndpoint: 'us-central1-aiplatform.googleapis.com',
      };

      this.vertexClient = new v1.PredictionServiceClient(clientOptions);
    } else {
      if (config.apiKey) {
        logs.usingAPIKey();
        const authClient = new GoogleAuth().fromAPIKey(config.apiKey);
        this.generativeClient = new TextServiceClient({
          authClient,
        });
      } else {
        logs.usingADC();
        const auth = new GoogleAuth({
          scopes: [
            'https://www.googleapis.com/auth/userinfo.email',
            'https://www.googleapis.com/auth/generative-language',
          ],
        });
        this.generativeClient = new TextServiceClient({
          auth,
        });
      }
    }
  }

  /**
   * Converts the first prediction of a Vertex response into the common
   * response shape.
   *
   * @throws Error when the response carries no predictions.
   */
  private extractVertexCandidateResponse(
    result: VertexPredictResponse
  ): TextGeneratorResponse {
    if (!result.predictions || !result.predictions.length) {
      throw new Error('No predictions returned from Vertex AI.');
    }

    const predictionValue = result.predictions[0] as protobuf.common.IValue;

    const vertexPrediction = helpers.fromValue(predictionValue);

    return convertToTextGeneratorResponse(vertexPrediction as VertexPrediction);
  }

  /**
   * Generates text for `promptText` with the provider chosen by
   * `config.provider`.
   *
   * @param promptText The prompt sent to the model.
   * @param options Per-call overrides. For the Vertex provider they take
   *   precedence over the instance defaults; for the Generative Language
   *   provider they are spread into the request as-is.
   * @returns The candidate texts plus safety metadata.
   * @throws Error when the client for the active provider was not initialized,
   *   or when the provider returns neither content nor a block signal.
   */
  async generate(
    promptText: string,
    options: TextGeneratorRequestOptions = {}
  ): Promise<TextGeneratorResponse> {
    if (config.provider === 'vertex') {
      return this.generateWithVertex(promptText, options);
    }
    return this.generateWithGenerativeLanguage(promptText, options);
  }

  /** Issues a `predict` call against the Vertex endpoint built in the constructor. */
  private async generateWithVertex(
    promptText: string,
    options: TextGeneratorRequestOptions
  ): Promise<TextGeneratorResponse> {
    if (!this.vertexClient) {
      throw new Error('Vertex client not initialized.');
    }

    const prompt = {
      prompt: promptText,
    };
    // toValue is typed as possibly returning undefined; the assertion below
    // satisfies that typing for the object literal passed here.
    const instanceValue = helpers.toValue(prompt);
    const instances = [instanceValue!];

    // `??` (not `||`) so that an explicit 0 for temperature or topP is honoured
    // instead of being replaced by the instance default.
    const temperature = options.temperature ?? this.temperature;
    const topP = options.topP ?? this.topP;
    // topK and maxOutputTokens keep the truthiness fallback: 0 is treated as
    // "not set" for both.
    const topK = options.topK || this.topK;
    const maxOutputTokens = options.maxOutputTokens || this.maxOutputTokens;

    const parameter: Record<string, string | number> = {};

    // We have to set these conditionally or they get nullified and the request fails with a serialization error.
    if (temperature != null) {
      parameter.temperature = temperature;
    }
    if (topP != null) {
      parameter.top_p = topP;
    }
    if (topK) {
      parameter.top_k = topK;
    }
    // NOTE(review): top_p/top_k are snake_case while maxOutputTokens is
    // camelCase - confirm the service accepts this mix of key spellings.
    parameter.maxOutputTokens = maxOutputTokens;

    const parameters = helpers.toValue(parameter);

    const request = {
      endpoint: this.endpoint,
      instances,
      parameters,
    };

    const [result] = await this.vertexClient.predict(request);

    return this.extractVertexCandidateResponse(result);
  }

  /** Issues a `generateText` call against the Generative Language client. */
  private async generateWithGenerativeLanguage(
    promptText: string,
    options: TextGeneratorRequestOptions
  ): Promise<TextGeneratorResponse> {
    // The instance-level safety settings are listed after the spread, so they
    // always replace any `safetySettings` supplied in `options`.
    // NOTE(review): unlike the Vertex path, the instance-level sampling
    // defaults (temperature, topP, topK, candidateCount, maxOutputTokens) are
    // not applied here - confirm that callers pass them per call.
    const request = {
      prompt: {
        text: promptText,
      },
      model: `models/${this.model}`,
      ...options,
      safetySettings: this.generativeSafetySettings,
    };

    if (!this.generativeClient) {
      throw new Error('Generative Language Client not initialized.');
    }

    const [result] = await this.generativeClient.generateText(request);

    return convertToTextGeneratorResponse(result as GenerativePrediction);
  }
}

/** Fields read from a Vertex prediction after `helpers.fromValue`. */
type VertexPrediction = {
  safetyAttributes?: {
    blocked: boolean;
    categories: string[];
    scores: number[];
  };
  content?: string;
};

/** Fields read from a Generative Language `generateText` result. */
type GenerativePrediction = {
  candidates: {output: string}[];
  filters?: {reason: string}[];
  safetyFeedback?: {
    rating: Record<string, any>;
    setting: Record<string, any>;
  }[];
};

/** Provider-independent result returned by {@link TextGenerator.generate}. */
type TextGeneratorResponse = {
  candidates: string[];
  safetyMetadata?: {
    blocked: boolean;
    [key: string]: any;
  };
};

/**
 * Normalizes a provider-specific prediction into a
 * {@link TextGeneratorResponse}.
 *
 * A prediction that has a `candidates` property is handled as a Generative
 * Language result; anything else is handled as a Vertex prediction.
 *
 * @throws Error when the prediction has no output and is not marked as blocked.
 */
function convertToTextGeneratorResponse(
  prediction: VertexPrediction | GenerativePrediction
): TextGeneratorResponse {
  // if it's generative language
  if ('candidates' in prediction) {
    const {candidates, filters, safetyFeedback} = prediction;
    // Any entry in `filters` marks the response as blocked.
    const blocked = !!filters && filters.length > 0;
    const safetyMetadata = {
      blocked,
      safetyFeedback,
    };
    if (!candidates.length && !blocked) {
      throw new Error('No candidates returned from the Generative API.');
    }
    return {
      candidates: candidates.map(candidate => candidate.output),
      safetyMetadata,
    };
  }

  // provider will be vertex
  const {content, safetyAttributes} = prediction;
  const blocked = !!safetyAttributes && !!safetyAttributes.blocked;
  const safetyMetadata = {
    blocked,
    safetyAttributes,
  };
  // A blocked prediction yields no candidates, whether or not content is set.
  if (blocked) {
    return {
      candidates: [],
      safetyMetadata,
    };
  }
  if (!content) {
    throw new Error('No content returned from the Vertex PaLM API.');
  }
  return {
    candidates: [content],
    safetyMetadata,
  };
}