maga_transformer/openai/renderers/qwen_agent/llm/qwen_dashscope.py (197 lines of code) (raw):

import os
from http import HTTPStatus
from pprint import pformat
from typing import Dict, Iterator, List, Optional, Union, Literal

import dashscope

from qwen_agent.llm.base import ModelServiceError, register_llm
from qwen_agent.llm.schema import ASSISTANT, DEFAULT_SYSTEM_MESSAGE, SYSTEM, USER, Message
from qwen_agent.llm.text_base import BaseTextChatModel
from qwen_agent.log import logger

# region youmi
import copy
from qwen_agent.settings import DEFAULT_MAX_INPUT_TOKENS
from qwen_agent.utils.utils import has_chinese_messages, merge_generate_cfgs
from qwen_agent.llm.base import _truncate_input_messages_roughly
# end region


@register_llm('qwen_dashscope')
class QwenChatAtDS(BaseTextChatModel):
    """Text chat model backed by ``dashscope.Generation.call``.

    Registered under the name ``'qwen_dashscope'``.
    """

    def __init__(self, cfg: Optional[Dict] = None):
        """Set up the model name and the global DashScope client settings.

        Args:
            cfg: Optional configuration dict. It is forwarded to the base
                class and to ``initialize_dashscope``.
        """
        super().__init__(cfg)
        # Fall back to 'qwen-max' when the base class left the model unset.
        self.model = self.model or 'qwen-max'
        initialize_dashscope(cfg)

    def _chat_stream(
        self,
        messages: List[Message],
        delta_stream: bool,
        generate_cfg: dict,
    ) -> Iterator[List[Message]]:
        """Send a streaming chat request.

        Args:
            messages: Conversation to send; each item is dumped to a dict.
            delta_stream: If true, yield only the newly generated text of each
                step; otherwise yield the content of every chunk as received.
            generate_cfg: Extra keyword arguments for
                ``dashscope.Generation.call``.

        Returns:
            An iterator of single-element assistant message lists.
        """
        messages = [msg.model_dump() for msg in messages]
        logger.debug(f'LLM Input:\n{pformat(messages, indent=2)}')
        response = dashscope.Generation.call(
            self.model,
            messages=messages,  # noqa
            result_format='message',
            stream=True,
            **generate_cfg)
        if delta_stream:
            return self._delta_stream_output(response)
        return self._full_stream_output(response)

    def _chat_no_stream(
        self,
        messages: List[Message],
        generate_cfg: dict,
    ) -> List[Message]:
        """Send a non-streaming chat request.

        Args:
            messages: Conversation to send; each item is dumped to a dict.
            generate_cfg: Extra keyword arguments for
                ``dashscope.Generation.call``.

        Returns:
            A single-element list holding the assistant reply.

        Raises:
            ModelServiceError: If the response status is not ``HTTPStatus.OK``.
        """
        messages = [msg.model_dump() for msg in messages]
        logger.debug(f'LLM Input:\n{pformat(messages, indent=2)}')
        response = dashscope.Generation.call(
            self.model,
            messages=messages,  # noqa
            result_format='message',
            stream=False,
            **generate_cfg)
        if response.status_code != HTTPStatus.OK:
            raise ModelServiceError(code=response.code, message=response.message)
        return [Message(ASSISTANT, response.output.choices[0].message.content)]

    def _continue_assistant_response(
        self,
        messages: List[Message],
        generate_cfg: dict,
        stream: bool,
    ) -> Union[List[Message], Iterator[List[Message]]]:
        """Continue generation from a raw text-completion prompt.

        The conversation is flattened with ``_build_text_completion_prompt``
        and sent with ``use_raw_prompt=True``. The request itself is always
        streamed; ``stream`` only controls what is handed back.

        Args:
            messages: Conversation to build the prompt from.
            generate_cfg: Extra keyword arguments for
                ``dashscope.Generation.call``.
            stream: If true, return the iterator of responses; otherwise
                consume it and return only the last item.

        Returns:
            The response iterator when ``stream`` is true, else the final
            list of messages.
        """
        prompt = self._build_text_completion_prompt(messages)
        logger.debug(f'LLM Input:\n{pformat(prompt, indent=2)}')
        response = dashscope.Generation.call(
            self.model,
            prompt=prompt,  # noqa
            result_format='message',
            stream=True,
            use_raw_prompt=True,
            **generate_cfg)
        it = self._full_stream_output(response)
        if stream:
            return it  # streaming the response
        *_, final_response = it  # return the final response without streaming
        return final_response

    # region youmi generate completion prompt
    def generate_completion_prompt(
        self,
        messages: List[Union[Message, Dict]],
        functions: Optional[List[Dict]] = None,
        extra_generate_cfg: Optional[Dict] = None,
    ) -> str:
        """Build the text-completion prompt without calling the model.

        Adapted from the LLM chat interface.

        Args:
            messages: Inputted messages (``Message`` objects or dicts). Must
                not be empty, because the first message is inspected.
            functions: Inputted functions for function calling. OpenAI format
                supported.
            extra_generate_cfg: Generation settings merged over
                ``self.generate_cfg``.

        Returns:
            The generated prompt.

        Raises:
            ValueError: If ``function_choice`` is not ``'auto'``, ``'none'``
                or the name of one of ``functions``.
            NotImplementedError: If ``function_choice`` is ``'none'``.
        """
        generate_cfg = merge_generate_cfgs(base_generate_cfg=self.generate_cfg,
                                           new_generate_cfg=extra_generate_cfg)

        # Explicit 'lang' wins; otherwise detect it from the raw input messages.
        lang: Literal['en', 'zh']
        if 'lang' in generate_cfg:
            lang = generate_cfg.pop('lang')
        else:
            lang = 'zh' if has_chinese_messages(messages) else 'en'

        # Work on a deep copy so the caller's messages are never modified,
        # and normalise dict entries to Message objects.
        messages = [
            Message(**msg) if isinstance(msg, dict) else msg
            for msg in copy.deepcopy(messages)
        ]

        if messages[0].role != SYSTEM:
            messages = [Message(role=SYSTEM, content=DEFAULT_SYSTEM_MESSAGE)] + messages

        # Not precise. It's hard to estimate tokens related with function calling and multimodal items.
        messages = _truncate_input_messages_roughly(
            messages=messages,
            max_tokens=generate_cfg.pop('max_input_tokens', DEFAULT_MAX_INPUT_TOKENS),
        )
        messages = self._preprocess_messages(messages, lang=lang)

        if 'function_choice' in generate_cfg:
            fn_choice = generate_cfg['function_choice']
            fn_names = [f.get('name', f.get('name_for_model', None)) for f in (functions or [])]
            valid_fn_choices = ['auto', 'none'] + [name for name in fn_names if name]
            if fn_choice not in valid_fn_choices:
                raise ValueError(f'The value of function_choice must be one of the following: {valid_fn_choices}. '
                                 f'But function_choice="{fn_choice}" is received.')
            if fn_choice == 'auto':
                del generate_cfg['function_choice']
            if fn_choice == 'none':
                raise NotImplementedError('Not implemented function_choice="none" yet.')  # TODO:

        # These keys only steer function calling; drop them from the local cfg.
        for k in ['parallel_function_calls', 'function_choice']:
            generate_cfg.pop(k, None)

        if functions:
            messages = self._prepend_fncall_system(messages, functions, lang=lang)
        return self._build_text_completion_prompt(messages)
    # end region

    @staticmethod
    def _build_text_completion_prompt(messages: List[Message]) -> str:
        """Flatten a conversation into an ``<|im_start|>``/``<|im_end|>`` prompt.

        The prompt starts with a system turn (the first message if it is a
        system message, else ``DEFAULT_SYSTEM_MESSAGE``), followed by every
        user and assistant turn. The closing ``<|im_end|>`` of the last
        assistant turn is removed so the model continues that turn.

        Args:
            messages: Non-empty conversation whose contents are all strings.
                The list is not modified.

        Returns:
            The prompt string, left open inside the final assistant turn.
        """
        im_start = '<|im_start|>'
        im_end = '<|im_end|>'

        if messages[0].role == SYSTEM:
            system_text = messages[0].content
            assert isinstance(system_text, str)
            prompt = f'{im_start}{SYSTEM}\n{system_text}{im_end}'
        else:
            prompt = f'{im_start}{SYSTEM}\n{DEFAULT_SYSTEM_MESSAGE}{im_end}'

        if messages[-1].role != ASSISTANT:
            # Build a new list rather than appending in place: the empty
            # assistant turn must not leak into the caller's conversation.
            messages = [*messages, Message(ASSISTANT, '')]

        for msg in messages:
            assert isinstance(msg.content, str)
            if msg.role == USER:
                query = msg.content.lstrip('\n').rstrip()
                prompt += f'\n{im_start}{USER}\n{query}{im_end}'
            elif msg.role == ASSISTANT:
                response = msg.content.lstrip('\n').rstrip()
                prompt += f'\n{im_start}{ASSISTANT}\n{response}{im_end}'

        # The last turn is always an assistant turn; strip its end marker.
        assert prompt.endswith(im_end)
        prompt = prompt[:-len(im_end)]
        return prompt

    @staticmethod
    def _delta_stream_output(response) -> Iterator[List[Message]]:
        """Yield only the newly generated text of each streamed chunk.

        Each chunk's content is treated as the text generated so far. The
        last ``delay_len`` characters are held back until more text arrives,
        and whatever is still held back is flushed after the final chunk.

        Args:
            response: Iterable of streamed chunks.

        Yields:
            Single-element lists with the incremental assistant text.

        Raises:
            ModelServiceError: If a chunk's status is not ``HTTPStatus.OK``.
        """
        last_len = 0  # number of characters already emitted
        delay_len = 5  # number of trailing characters withheld
        in_delay = False
        text = ''
        for chunk in response:
            if chunk.status_code != HTTPStatus.OK:
                raise ModelServiceError(code=chunk.code, message=chunk.message)
            text = chunk.output.choices[0].message.content
            if (len(text) - last_len) <= delay_len:
                # Not enough new text beyond the withheld tail yet.
                in_delay = True
                continue
            in_delay = False
            real_text = text[:-delay_len]
            now_rsp = real_text[last_len:]
            yield [Message(ASSISTANT, now_rsp)]
            last_len = len(real_text)
        # Flush the withheld tail once the stream is exhausted.
        if text and (in_delay or (last_len != len(text))):
            yield [Message(ASSISTANT, text[last_len:])]

    @staticmethod
    def _full_stream_output(response) -> Iterator[List[Message]]:
        """Yield the content of every streamed chunk as received.

        Args:
            response: Iterable of streamed chunks.

        Yields:
            Single-element lists with the chunk's assistant content.

        Raises:
            ModelServiceError: If a chunk's status is not ``HTTPStatus.OK``.
        """
        for chunk in response:
            if chunk.status_code != HTTPStatus.OK:
                raise ModelServiceError(code=chunk.code, message=chunk.message)
            yield [Message(ASSISTANT, chunk.output.choices[0].message.content)]


def initialize_dashscope(cfg: Optional[Dict] = None) -> None:
    """Configure the module-level ``dashscope`` settings.

    Each setting is taken from ``cfg`` when present and truthy, otherwise
    from the environment: ``DASHSCOPE_API_KEY`` (default ``'EMPTY'``),
    ``DASHSCOPE_HTTP_URL`` and ``DASHSCOPE_WEBSOCKET_URL``. The URLs are only
    assigned when a value was found. All values are stripped of surrounding
    whitespace.

    Args:
        cfg: Optional dict that may hold ``api_key``, ``base_http_api_url``
            and ``base_websocket_api_url``.
    """
    cfg = cfg or {}

    api_key = cfg.get('api_key', '') or os.getenv('DASHSCOPE_API_KEY', 'EMPTY')
    base_http_api_url = cfg.get('base_http_api_url', None) or os.getenv('DASHSCOPE_HTTP_URL', None)
    base_websocket_api_url = (cfg.get('base_websocket_api_url', None) or
                              os.getenv('DASHSCOPE_WEBSOCKET_URL', None))

    dashscope.api_key = api_key.strip()
    if base_http_api_url is not None:
        dashscope.base_http_api_url = base_http_api_url.strip()
    if base_websocket_api_url is not None:
        dashscope.base_websocket_api_url = base_websocket_api_url.strip()