processors/vertexgenai.py (260 lines of code) (raw):

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import urllib
import urllib.error
import urllib.request

import google.auth
import google.auth.transport.requests
import google.oauth2.id_token
from google.auth.transport.requests import AuthorizedSession

from helpers.base import get_user_agent
from .base import Processor, NotConfiguredException


class VertexgenaiProcessor(Processor):
    """
    Vertex AI Generative AI processor.

    Args:
        location (str): Location of the Vertex AI endpoint to use (used both
          in the API hostname and in the resource path).
        modelId (str): Deployed model to use.
        project (str, optional): Google Cloud project ID.
        method (str, optional): Method to call, by default: "predict"
        returnErrors (bool, optional): Set to true to return errors
        callFunctions (dict, optional): URLs for functions.
        request (dict): Request.
    """

    # Lower-cased header names that must never be written to the logs.
    SENSITIVE_HEADERS = frozenset({
        'authorization',
        'x-serverless-authorization',
        'api-key',
        'x-api-key',
        'proxy-authorization',
    })

    @staticmethod
    def get_default_config_key():
        """Return the configuration key used for this processor."""
        return 'vertexgenai'

    def call_function(self, name, params):
        """Call an HTTP endpoint on behalf of a model function call.

        Args:
            name (str): Function name, used for logging only.
            params (dict): Call definition. Reads ``url`` (required) and the
              optional keys ``method`` (defaults to ``GET``), ``headers``
              (list of ``{'name': ..., 'value': ...}`` dicts), ``body``
              (dicts and lists are JSON encoded), ``idToken`` (when truthy,
              an ID token is fetched and sent as a bearer token) and
              ``audience`` (ID token audience, defaults to ``url``).

        Returns:
            tuple: ``(response headers, JSON-decoded response body)``.

        Raises:
            urllib.error.HTTPError: When the endpoint answers with an error
              status; the response body is logged before re-raising.
        """
        method = params['method'].upper() if 'method' in params else 'GET'

        headers = {}
        for header in params.get('headers') or []:
            headers[header['name'].lower()] = header['value']

        body = params.get('body') or ''
        if isinstance(body, (dict, list)):
            body = json.dumps(body)
        data = body if isinstance(body, bytes) else body.encode('utf-8')

        # Build a separate, redacted mapping for logging. Deleting keys from
        # the request headers themselves would strip the credentials from
        # the outgoing call.
        loggable_headers = {
            key: value
            for key, value in headers.items()
            if key not in self.SENSITIVE_HEADERS
        }
        self.logger.info('Calling function: %s' % (name),
                         extra={
                             'url': params['url'],
                             'method': method,
                             'body_length': len(body),
                             'headers': loggable_headers,
                             'id_token': 'idToken' in params
                         })

        if 'idToken' in params and params['idToken']:
            audience = params['audience'] if 'audience' in params else params[
                'url']
            auth_request = google.auth.transport.requests.Request()
            id_token = google.oauth2.id_token.fetch_id_token(
                auth_request, audience)
            headers['authorization'] = 'Bearer %s' % (id_token)

        req = urllib.request.Request(params['url'],
                                     headers=headers,
                                     method=method)
        try:
            # The context manager closes the connection once the body has
            # been read.
            with urllib.request.urlopen(req, data=data) as response:
                response_headers = response.headers
                response_body = response.read().decode('utf-8')
        except urllib.error.HTTPError as e:
            # urlopen() raises for error statuses instead of returning the
            # response, so this is where failures have to be logged.
            self.logger.error('Error calling function: %s' % (name),
                              extra={
                                  'status_code':
                                      e.code,
                                  'response':
                                      e.read().decode('utf-8',
                                                      errors='replace')
                              })
            raise
        return (response_headers, json.loads(response_body))

    @staticmethod
    def _collapse_consecutive(items, key):
        """Keep only the last item of every run of items sharing ``key``.

        Used to make the conversation alternate between speakers: of several
        consecutive entries from the same author/role, the latest one wins.
        """
        kept = []
        previous = object()  # sentinel that never equals a real value
        for item in reversed(items):
            if item[key] != previous:
                kept.append(item)
                previous = item[key]
        kept.reverse()
        return kept

    @staticmethod
    def _error_result(response, output_var):
        """Build the ``returnErrors`` result for a failed HTTP response.

        Returns None when the response carries an empty JSON list, in which
        case the caller re-raises the original error.
        """
        try:
            payload = response.json()
        except ValueError:
            return {output_var: {'error': response.text}}
        # The body is either a list of chunks or a single object; treat both
        # the same way.
        if not isinstance(payload, list):
            payload = [payload]
        for err in payload:
            if isinstance(err, dict) and isinstance(
                    err.get('error'), dict) and 'message' in err['error']:
                return {output_var: {'error': err['error']['message']}}
            return {output_var: {'error': err}}
        return None

    def _call_vertex(self, session, api_path, request, headers, return_errors,
                     output_var):
        """POST ``request`` to Vertex AI and decode the answer.

        Returns:
            tuple: ``(response_json, None)`` on success, where
              ``response_json`` is always a list, or ``(None, result)`` when
              the call failed and ``return_errors`` asks for the error to be
              returned as the processor result.

        Raises:
            Exception: The HTTP or JSON decoding error, when errors are not
              to be returned.
        """
        response = session.post(api_path,
                                data=json.dumps(request),
                                headers=headers)
        try:
            response.raise_for_status()
        except Exception as e:
            self.logger.error('Error calling %s: %s' % (response.url, e),
                              extra={'response': response.text})
            if return_errors:
                error_result = self._error_result(response, output_var)
                if error_result is not None:
                    return None, error_result
            raise

        try:
            response_json = response.json()
        except ValueError as e:
            # The decode error carries no request/response of its own, so
            # report from the response object.
            self.logger.error('Response was not JSON from %s: %s' %
                              (response.url, e),
                              extra={'response': response.text})
            if return_errors:
                return None, {output_var: {'error': response.text}}
            raise

        if isinstance(response_json, dict):
            response_json = [response_json]
        return response_json, None

    def _collect_function_calls(self, response_json):
        """Find the function calls requested in a Vertex AI response.

        Returns:
            tuple: ``(function_calls, function_contents)``; the first maps a
              function name to its arguments, the second maps it to the
              candidate content that requested it. Both are keyed by name,
              so a later call to the same function replaces an earlier one.
        """
        function_calls = {}
        function_contents = {}
        for response_part in response_json:
            if 'candidates' not in response_part:
                self.logger.warning('No candidates in Vertex response.',
                                    extra={'response_part': response_part})
                continue
            for candidate in response_part['candidates']:
                if 'content' not in candidate:
                    self.logger.warning(
                        'No content in Vertex response candidate.',
                        extra={'candidate': candidate})
                    continue
                content = candidate['content']
                if 'parts' not in content:
                    self.logger.warning(
                        'No parts in Vertex response candidate content.',
                        extra={'candidate': candidate})
                    continue
                for part in content['parts']:
                    if 'functionCall' not in part:
                        continue
                    function_call = part['functionCall']
                    function_name = function_call['name']
                    args = function_call[
                        'args'] if 'args' in function_call else {}
                    self.logger.info('Vertex wants us to call function %s.' %
                                     (function_name),
                                     extra={'function_args': args})
                    function_contents[function_name] = content
                    function_calls[function_name] = args
        return function_calls, function_contents

    def process(self, output_var='vertexgenai'):
        """Send the configured request to Vertex AI and return the response.

        When ``callFunctions`` is configured and the model asks for function
        calls, the configured functions are called and the request is sent
        once more with their responses appended.

        Args:
            output_var (str): Key under which the result is returned.

        Returns:
            dict: ``{output_var: [response, ...]}`` on success, or
              ``{output_var: {'error': ...}}`` when ``returnErrors`` is set
              and the call failed.

        Raises:
            NotConfiguredException: When ``location``, ``modelId`` or
              ``request`` is missing from the configuration.
        """
        if 'location' not in self.config:
            raise NotConfiguredException('No location specified.')
        if 'modelId' not in self.config:
            raise NotConfiguredException('No model ID specified.')
        if 'request' not in self.config:
            raise NotConfiguredException('No request specified.')

        credentials, credentials_project_id = google.auth.default()
        project = self.config[
            'project'] if 'project' in self.config else credentials_project_id
        if not project:
            project = credentials.quota_project_id

        method = self._jinja_expand_string(
            self.config['method'],
            'method') if 'method' in self.config else 'predict'
        location = self._jinja_expand_string(self.config['location'],
                                             'location')
        model_id = self._jinja_expand_string(self.config['modelId'], 'modelId')
        api_path = 'https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s' % (
            location, project, location, model_id, method)

        return_errors = False
        if 'returnErrors' in self.config:
            return_errors = self._jinja_expand_bool(
                self.config['returnErrors'], 'return_errors')

        request = self._jinja_expand_dict_all_expr(self.config['request'],
                                                   'request')
        headers = {
            'User-Agent': get_user_agent(),
            'Content-type': 'application/json; charset=utf-8',
        }

        # Messages must be USER, AI, USER, AI.. etc, so filter them
        # to make life easier.
        if method == 'predict':
            if 'instances' in request:
                for instance in request['instances']:
                    if 'messages' in instance:
                        instance['messages'] = self._collapse_consecutive(
                            instance['messages'], 'author')
        elif 'contents' in request:
            request['contents'] = self._collapse_consecutive(
                request['contents'], 'role')

        self.logger.debug('Calling Vertex AI %s' % (method),
                          extra={
                              'request_body': request,
                              'api_url': api_path
                          })
        authed_session = AuthorizedSession(credentials)
        response_json, error_result = self._call_vertex(
            authed_session, api_path, request, headers, return_errors,
            output_var)
        if error_result is not None:
            return error_result

        # Check if functions need to be called
        if 'callFunctions' in self.config:
            function_calls, function_contents = self._collect_function_calls(
                response_json)

            function_responses = {}
            jinja_globals = self.jinja_environment.globals
            for name, args in function_calls.items():
                # Snapshot the globals so the function arguments exposed to
                # the templates below do not leak into later expansions.
                saved_globals = dict(jinja_globals)
                try:
                    jinja_globals.update(args)
                    defined_functions = self._jinja_expand_dict_all(
                        self.config['callFunctions'], 'call_functions')
                    if name in defined_functions:
                        function_responses[name] = self.call_function(
                            name, defined_functions[name])
                    else:
                        self.logger.error(
                            'No function configuration specified for: %s' %
                            (name),
                            extra={'function_name': name})
                finally:
                    jinja_globals.clear()
                    jinja_globals.update(saved_globals)

            if function_responses:
                contents = request.setdefault('contents', [])
                for name, result in function_responses.items():
                    contents.append(function_contents[name])
                    contents.append({
                        'role':
                            'MODEL',
                        'parts': [{
                            'functionResponse': {
                                'name': name,
                                'response': result[1],
                            }
                        }]
                    })

                self.logger.debug(
                    'Re-doing Vertex request after adding function responses.',
                    extra={'request': request})
                response_json, error_result = self._call_vertex(
                    authed_session, api_path, request, headers,
                    return_errors, output_var)
                if error_result is not None:
                    return error_result

        return {
            output_var: response_json,
        }