dialogflow-cx/dialogflow_sample.py (150 lines of code) (raw):

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for the base class for all Dialogflow CX samples."""

import time
import uuid

import google.api_core.exceptions
import google.cloud.dialogflowcx as cx


class UnexpectedResponseFailure(AssertionError):
    """Raised when a test run reports differences in the agent's responses."""


class TestCaseFailure(AssertionError):
    """Raised when a test run's overall result is not ``PASSED``."""


class SessionParametersFailure(AssertionError):
    """Raised when per-turn session parameters differ from the expected ones."""


class DialogflowSample:
    """Base class for samples.

    A sample is wired together from delegator objects (auth, agent, start
    flow, session) that are injected through the ``set_*`` methods and read
    back through the matching properties.
    """

    def __init__(self) -> None:
        self._agent_delegator = None
        self._auth_delegator = None
        self._credentials = None
        # Created lazily by the ``test_cases_client`` property.
        self._test_cases_client = None
        self._start_flow_delegator = None
        self._session_delegator = None

    def set_auth_delegator(self, auth_delegator):
        """Sets the AuthDelegator for the sample."""
        self._auth_delegator = auth_delegator

    def set_agent_delegator(self, agent_delegator):
        """Sets the AgentDelegator for the sample."""
        self._agent_delegator = agent_delegator

    def set_session_delegator(self, session_delegator):
        """Sets the SessionDelegator for the sample."""
        self._session_delegator = session_delegator

    def set_start_flow_delegator(self, start_flow_delegator):
        """Sets the StartFlowDelegator for the sample."""
        self._start_flow_delegator = start_flow_delegator

    def set_credentials(self, credentials):
        """Sets the credentials for the sample."""
        self._credentials = credentials

    @property
    def auth_delegator(self):
        """Accesses the auth_delegator for the sample."""
        return self._auth_delegator

    @property
    def agent_delegator(self):
        """Accesses the agent_delegator for the sample."""
        return self._agent_delegator

    @property
    def start_flow_delegator(self):
        """Accesses the start_flow_delegator for the sample."""
        return self._start_flow_delegator

    @property
    def session_delegator(self):
        """Accesses the session_delegator for the sample."""
        return self._session_delegator

    @property
    def credentials(self):
        """Accesses the credentials for the sample."""
        return self._credentials

    @property
    def project_id(self):
        """Accesses the project ID for the sample (from the auth delegator)."""
        return self.auth_delegator.project_id

    @property
    def location(self):
        """Accesses the location ID for the sample (from the auth delegator)."""
        return self.auth_delegator.location

    @property
    def start_flow(self):
        """Accesses the start_flow for the sample (from the agent delegator)."""
        return self.agent_delegator.start_flow

    @property
    def client_options(self):
        """Accesses the client_options (regional API endpoint) for API clients."""
        return {"api_endpoint": f"{self.location}-dialogflow.googleapis.com"}

    @property
    def test_cases_client(self):
        """Accesses the TestCasesClient for the sample, creating it on first use."""
        if self._test_cases_client is None:
            self._test_cases_client = cx.TestCasesClient(
                client_options=self.client_options,
                credentials=self.credentials,
            )
        return self._test_cases_client

    def setup(self, wait=0):
        """Set up sample. Especially, train the start flow.

        Args:
            wait: Seconds to sleep after the training operation has stopped
                running.
        """
        request = cx.TrainFlowRequest(name=self.start_flow_delegator.flow.name)
        lro = self.start_flow_delegator.client.train_flow(request=request)
        # Poll until the long-running training operation is no longer running.
        while lro.running():
            time.sleep(0.1)
        time.sleep(wait)

    # pylint: disable=too-many-arguments
    def run(
        self,
        user_text_list,
        session_id=None,
        wait=1,
        parameters=None,
        current_page=None,
        quiet=False,
    ):
        """Runs a conversation with this agent.

        Each text in ``user_text_list`` is sent through the session
        delegator's ``detect_intent``; the page and parameters returned by
        one turn are fed into the next.

        Args:
            user_text_list: Iterable of user utterances, sent in order.
            session_id: Session to use; a fresh UUID is generated when falsy.
            wait: Seconds to sleep before the conversation starts.
            parameters: Starting session parameters; defaults to ``{}``.
            current_page: Page to start the conversation from.
            quiet: When true, suppresses the per-turn transcript printing.
        """
        time.sleep(wait)
        if parameters is None:
            parameters = {}
        if not session_id:
            session_id = str(uuid.uuid1())
        for text in user_text_list:
            if not quiet:
                print("User: ")
                print(f" Text: {text}")
                print(f" Starting Parameters: {parameters}")
                print(f" Page: {current_page}")
            responses, current_page, parameters = self.session_delegator.detect_intent(
                text,
                parameters=parameters,
                current_page=current_page,
                session_id=session_id,
            )
            if not quiet:
                print(" Agent:")
                for response in responses:
                    print(f" Text: {response}")
                print(f" Ending Parameters: {parameters}")
                print(f" Ending Page: {current_page}")

    def create_test_case(self, display_name, test_case_conversation_turns, flow=None):
        """Create a test case, or fetch the existing one with the same name.

        Args:
            display_name: Display name of the test case.
            test_case_conversation_turns: Conversation turns for the test case.
            flow: Flow to test; defaults to this sample's ``start_flow``.

        Returns:
            The created test case, or the already-existing test case whose
            display name equals ``display_name``.

        Raises:
            google.api_core.exceptions.AlreadyExists: If creation reports a
                conflict but no listed test case has a matching display name.
        """
        if flow is None:
            flow = self.start_flow
        try:
            test_case = self.test_cases_client.create_test_case(
                parent=self.agent_delegator.agent.name,
                test_case=cx.TestCase(
                    display_name=display_name,
                    test_case_conversation_turns=test_case_conversation_turns,
                    test_config=cx.TestConfig(flow=flow),
                ),
            )
        except google.api_core.exceptions.AlreadyExists:
            list_request = cx.ListTestCasesRequest(
                parent=self.agent_delegator.agent.name
            )
            for curr_test_case in self.test_cases_client.list_test_cases(
                request=list_request
            ):
                if curr_test_case.display_name == display_name:
                    # Re-fetch by resource name rather than returning the
                    # listed entry (presumably the listing is a summary
                    # view -- confirm against the API docs).
                    test_case = self.test_cases_client.get_test_case(
                        request=cx.GetTestCaseRequest(name=curr_test_case.name)
                    )
                    break
            else:
                # No match found: surface the original AlreadyExists error
                # instead of failing later with an unbound ``test_case``.
                raise
        return test_case

    def run_test_case(self, test_case, expected_session_parameters):
        """Runs a test case using TestCases API.

        Args:
            test_case: Test case to run; only its ``name`` is used.
            expected_session_parameters: List with one dict of expected
                session parameters per conversation turn (``{}`` for a turn
                without parameters).

        Raises:
            UnexpectedResponseFailure: If any turn reports response
                differences.
            SessionParametersFailure: If the per-turn session parameters do
                not equal ``expected_session_parameters``.
            TestCaseFailure: If the overall test result is not ``PASSED``.
        """
        lro = self.test_cases_client.run_test_case(
            request=cx.RunTestCaseRequest(name=test_case.name)
        )
        while lro.running():
            time.sleep(0.1)
        result = lro.result().result

        agent_response_differences = [
            conversation_turn.virtual_agent_output.differences
            for conversation_turn in result.conversation_turns
        ]
        if any(agent_response_differences):
            raise UnexpectedResponseFailure(agent_response_differences)

        final_session_parameters = []
        for conversation_turn in result.conversation_turns:
            session_parameters = conversation_turn.virtual_agent_output.session_parameters
            # Normalise to plain dicts so the list compares with ``!=`` below.
            if session_parameters:
                final_session_parameters.append(dict(session_parameters))
            else:
                final_session_parameters.append({})
        if expected_session_parameters != final_session_parameters:
            raise SessionParametersFailure(
                f"{expected_session_parameters!r} != {final_session_parameters!r}"
            )

        if result.test_result != cx.TestResult.PASSED:
            raise TestCaseFailure(
                f"Test case result was {result.test_result!r}, expected PASSED"
            )