processors/vertexai.py (92 lines of code) (raw):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import Processor, NotConfiguredException
from helpers.base import get_user_agent
import json
import google.auth
from google.auth.transport.requests import AuthorizedSession
class VertexaiProcessor(Processor):
    """
    Vertex AI processor.

    Sends a JSON request to a Vertex AI Search (Discovery Engine) serving
    configuration and returns the decoded JSON response.

    Args:
        mode (str): search (Vertex AI Search)
        method (str, optional): Method to call (defaults to "search").
        project (str, optional): Google Cloud project ID. Falls back to the
            project (or quota project) of the default credentials.
        location (str): Location for Vertex AI. Used both as the prefix of
            the API endpoint host name and in the resource path.
        collection (str, optional): Collection (defaults to "default_collection")
        engineId (str): Engine ID.
        datastoreId (str): Data store ID (either this or engineId).
        servingConfig (str, optional): Serving configuration (defaults to "default_config").
        returnErrors (bool, optional): Set to true to return errors
        apiVersion (str, optional): API version, defaults to "v1".
        request (dict): Request.
    """

    @staticmethod
    def get_default_config_key():
        """Return the configuration key under which this processor is set up."""
        return 'vertexai'

    @staticmethod
    def _extract_error_message(payload):
        """Return the first ``error.message`` found in an API error payload.

        The payload may be a single object (``{"error": {...}}``) or a list
        of such objects. Returns None when no message can be found.
        """
        candidates = payload if isinstance(payload, list) else [payload]
        for candidate in candidates:
            if not isinstance(candidate, dict):
                continue
            error = candidate.get('error')
            if isinstance(error, dict) and 'message' in error:
                return error['message']
        return None

    def process(self, output_var='vertexai'):
        """Call the configured Vertex AI Search method.

        Args:
            output_var (str): Key under which the result is returned.

        Returns:
            dict: ``{output_var: <response JSON>}`` on success, or
            ``{output_var: {'error': <message>}}`` when the call fails,
            ``returnErrors`` is enabled and the response carries an error
            message.

        Raises:
            NotConfiguredException: If a required setting is missing or the
                mode is not supported.
            Exception: The HTTP error from the API call when it is not
                returned via ``returnErrors``.
        """
        if 'location' not in self.config:
            raise NotConfiguredException('No location specified.')
        if 'mode' not in self.config:
            raise NotConfiguredException('No mode specified.')
        if 'request' not in self.config:
            raise NotConfiguredException('No request specified.')
        # Fail with a configuration error instead of a bare KeyError later on.
        if 'engineId' not in self.config and 'datastoreId' not in self.config:
            raise NotConfiguredException(
                'No engineId or datastoreId specified.')

        credentials, credentials_project_id = google.auth.default()
        project = self.config[
            'project'] if 'project' in self.config else credentials_project_id
        if not project:
            project = credentials.quota_project_id

        mode = self._jinja_expand_string(self.config['mode'], 'mode')
        if mode != 'search':
            # Only Vertex AI Search is implemented; without this guard the
            # API URL would be left undefined below.
            raise NotConfiguredException('Unsupported mode: %s' % (mode))

        location = self._jinja_expand_string(self.config['location'],
                                             'location')
        collection = self._jinja_expand_string(
            self.config['collection'], 'collection'
        ) if 'collection' in self.config else 'default_collection'
        serving_config = self._jinja_expand_string(
            self.config['servingConfig'], 'serving_config'
        ) if 'servingConfig' in self.config else 'default_config'
        api_version = self._jinja_expand_string(
            self.config['apiVersion'],
            'api_version') if 'apiVersion' in self.config else 'v1'
        method = self._jinja_expand_string(
            self.config['method'],
            'method') if 'method' in self.config else 'search'

        # engineId takes precedence over datastoreId when both are present.
        if 'engineId' in self.config:
            resource_type = 'engines'
            resource_id = self._jinja_expand_string(self.config['engineId'],
                                                    'engine_id')
        else:
            resource_type = 'dataStores'
            resource_id = self._jinja_expand_string(
                self.config['datastoreId'], 'datastore_id')

        # Both resource kinds share the same URL layout:
        # https://LOCATION-discoveryengine.googleapis.com/VERSION/projects/...
        api_path = 'https://%s-discoveryengine.googleapis.com/%s/projects/%s/locations/%s/collections/%s/%s/%s/servingConfigs/%s:%s' % (
            location, api_version, project, location, collection,
            resource_type, resource_id, serving_config, method)

        return_errors = False
        if 'returnErrors' in self.config:
            return_errors = self._jinja_expand_bool(
                self.config['returnErrors'], 'return_errors')

        request = self._jinja_expand_dict_all_expr(self.config['request'],
                                                   'request')
        headers = {
            'User-Agent': get_user_agent(),
            'Content-type': 'application/json; charset=utf-8',
        }
        self.logger.debug('Calling Vertex AI %s:%s' % (mode, method),
                          extra={
                              'request_body': request,
                              'api_url': api_path
                          })

        request_body = json.dumps(request)
        authed_session = AuthorizedSession(credentials)
        response = authed_session.post(api_path,
                                       data=request_body,
                                       headers=headers)
        try:
            response.raise_for_status()
        except Exception as e:
            if return_errors:
                try:
                    error_message = self._extract_error_message(
                        response.json())
                except ValueError:
                    # Error body was not JSON; fall through and re-raise the
                    # original HTTP error rather than masking it.
                    error_message = None
                if error_message is not None:
                    return {
                        output_var: {
                            'error': error_message
                        }
                    }
            self.logger.error('Error calling %s: %s' % (response.url, e),
                              extra={'response': response.text})
            raise

        response_json = response.json()
        return {
            output_var: response_json,
        }