import copy
import random
import time
from abc import ABC, abstractmethod
from pprint import pformat
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union

from qwen_agent.llm.schema import DEFAULT_SYSTEM_MESSAGE, SYSTEM, USER, Message
from qwen_agent.log import logger
from qwen_agent.settings import DEFAULT_MAX_INPUT_TOKENS
from qwen_agent.utils.tokenization_qwen import tokenizer
from qwen_agent.utils.utils import (extract_text_from_message, format_as_multimodal_message, format_as_text_message,
                                    has_chinese_messages, merge_generate_cfgs)

LLM_REGISTRY = {}


def register_llm(model_type):

    def decorator(cls):
        LLM_REGISTRY[model_type] = cls
        return cls

    return decorator
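

# Illustrative usage of the registry above (not part of the original module; the backend
# name is hypothetical). Concrete backends register themselves so that a model_type string
# can later be resolved to a model class via LLM_REGISTRY:
#
#     @register_llm('my_backend')
#     class MyBackendChatModel(BaseChatModel):
#         ...
#
#     assert LLM_REGISTRY['my_backend'] is MyBackendChatModel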


class ModelServiceError(Exception):

    def __init__(self,
                 exception: Optional[Exception] = None,
                 code: Optional[str] = None,
                 message: Optional[str] = None):
        if exception is not None:
            super().__init__(exception)
        else:
            super().__init__(f'\nError code: {code}. Error message: {message}')
        self.exception = exception
        self.code = code
        self.message = message


class BaseChatModel(ABC):
    """The base class of LLM"""

    @property
    def support_multimodal_input(self) -> bool:
        # Does the model support multimodal input natively? It affects how we preprocess the input.
        return False

    @property
    def support_multimodal_output(self) -> bool:
        # Does the model generate multimodal outputs beyond text? It affects how we post-process the output.
        return False

    def __init__(self, cfg: Optional[Dict] = None):
        cfg = cfg or {}
        self.model = cfg.get('model', '').strip()
        generate_cfg = copy.deepcopy(cfg.get('generate_cfg', {}))
        self.max_retries = generate_cfg.pop('max_retries', 0)
        self.generate_cfg = generate_cfg

    def quick_chat(self, prompt: str) -> str:
        *_, responses = self.chat(messages=[Message(role=USER, content=prompt)])
        assert len(responses) == 1
        assert not responses[0].function_call
        assert isinstance(responses[0].content, str)
        return responses[0].content

    def chat(
        self,
        messages: List[Union[Message, Dict]],
        functions: Optional[List[Dict]] = None,
        stream: bool = True,
        delta_stream: bool = False,
        extra_generate_cfg: Optional[Dict] = None,
    ) -> Union[List[Message], List[Dict], Iterator[List[Message]], Iterator[List[Dict]]]:
        """LLM chat interface.

        Args:
            messages: Input messages.
            functions: Input functions for function calling. The OpenAI format is supported.
            stream: Whether to use streaming generation.
            delta_stream: Whether to stream the response incrementally.
              (1) When False (recommended): Stream the full response on every iteration.
              (2) When True: Stream the chunked, i.e., delta, responses.
            extra_generate_cfg: Extra LLM generation hyper-parameters.

        Returns:
            The generated message list from the LLM.
        """
        if stream and delta_stream:
            logger.warning(
                'Support for `delta_stream=True` is deprecated. '
                'Please use `stream=True and delta_stream=False` or `stream=False` instead. '
                'Using `delta_stream=True` makes it difficult to implement advanced postprocessing and retry mechanisms.'
            )

        generate_cfg = merge_generate_cfgs(base_generate_cfg=self.generate_cfg, new_generate_cfg=extra_generate_cfg)
        if 'seed' not in generate_cfg:
            generate_cfg['seed'] = random.randint(a=0, b=2**30)
        if 'lang' in generate_cfg:
            lang: Literal['en', 'zh'] = generate_cfg.pop('lang')
        else:
            lang: Literal['en', 'zh'] = 'zh' if has_chinese_messages(messages) else 'en'

        messages = copy.deepcopy(messages)

        _return_message_type = 'dict'
        new_messages = []
        for msg in messages:
            if isinstance(msg, dict):
                new_messages.append(Message(**msg))
            else:
                new_messages.append(msg)
                _return_message_type = 'message'
        messages = new_messages

        if messages[0].role != SYSTEM:
            messages = [Message(role=SYSTEM, content=DEFAULT_SYSTEM_MESSAGE)] + messages

        # Not precise. It's hard to estimate the tokens related to function calling and multimodal items.
        max_input_tokens = generate_cfg.pop('max_input_tokens', DEFAULT_MAX_INPUT_TOKENS)
        if max_input_tokens > 0:
            messages = _truncate_input_messages_roughly(
                messages=messages,
                max_tokens=max_input_tokens,
            )

        if functions:
            fncall_mode = True
        else:
            fncall_mode = False

        if 'function_choice' in generate_cfg:
            fn_choice = generate_cfg['function_choice']
            valid_fn_choices = [f.get('name', f.get('name_for_model', None)) for f in (functions or [])]
            valid_fn_choices = ['auto', 'none'] + [f for f in valid_fn_choices if f]
            if fn_choice not in valid_fn_choices:
                raise ValueError(f'The value of function_choice must be one of the following: {valid_fn_choices}. '
                                 f'But function_choice="{fn_choice}" was received.')
            if fn_choice == 'none':
                fncall_mode = False

        # Note: the preprocessor's behavior could change if it receives function_choice="none"
        messages = self._preprocess_messages(messages, lang=lang, generate_cfg=generate_cfg, functions=functions)
        if not self.support_multimodal_input:
            messages = [format_as_text_message(msg, add_upload_info=False) for msg in messages]

        if not fncall_mode:
            for k in ['parallel_function_calls', 'function_choice']:
                if k in generate_cfg:
                    del generate_cfg[k]

        def _call_model_service():
            if fncall_mode:
                return self._chat_with_functions(
                    messages=messages,
                    functions=functions,
                    stream=stream,
                    delta_stream=delta_stream,
                    generate_cfg=generate_cfg,
                    lang=lang,
                )
            else:
                return self._chat(
                    messages,
                    stream=stream,
                    delta_stream=delta_stream,
                    generate_cfg=generate_cfg,
                )

        if stream and delta_stream:
            # No retry for delta streaming
            output = _call_model_service()
        elif stream and (not delta_stream):
            output = retry_model_service_iterator(_call_model_service, max_retries=self.max_retries)
        else:
            output = retry_model_service(_call_model_service, max_retries=self.max_retries)

        if isinstance(output, list):
            assert not stream
            logger.debug(f'LLM Output:\n{pformat([_.model_dump() for _ in output], indent=2)}')
            output = self._postprocess_messages(output, fncall_mode=fncall_mode, generate_cfg=generate_cfg)
            if not self.support_multimodal_output:
                output = _format_as_text_messages(messages=output)
            return self._convert_messages_to_target_type(output, _return_message_type)
        else:
            assert stream
            if delta_stream:
                # Hack: to avoid potential errors during the postprocessing of stop words when delta_stream=True.
                # We should never have implemented support for `delta_stream=True` in the first place!
                generate_cfg = copy.deepcopy(generate_cfg)  # copy to avoid conflicts with `_call_model_service`
                assert 'skip_stopword_postproc' not in generate_cfg
                generate_cfg['skip_stopword_postproc'] = True
            output = self._postprocess_messages_iterator(output, fncall_mode=fncall_mode, generate_cfg=generate_cfg)
            return self._convert_messages_iterator_to_target_type(output, _return_message_type)
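
    # Illustrative usage sketch (not part of the original module; `llm` is assumed to be an
    # instance of a concrete BaseChatModel subclass). With the default stream=True and
    # delta_stream=False, chat() yields progressively longer snapshots of the full response,
    # and the final snapshot is the complete answer:
    #
    #     for snapshot in llm.chat(messages=[{'role': 'user', 'content': 'Hi'}]):
    #         print(snapshot[-1]['content'])  # grows as generation proceeds
    #
    #     *_, final = llm.chat(messages=[{'role': 'user', 'content': 'Hi'}])  # keep only the last snapshot
    #     final = llm.chat(messages=[{'role': 'user', 'content': 'Hi'}], stream=False)  # non-streaming: a plain list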

    def _chat(
        self,
        messages: List[Union[Message, Dict]],
        stream: bool,
        delta_stream: bool,
        generate_cfg: dict,
    ) -> Union[List[Message], Iterator[List[Message]]]:
        if stream:
            return self._chat_stream(messages, delta_stream=delta_stream, generate_cfg=generate_cfg)
        else:
            return self._chat_no_stream(messages, generate_cfg=generate_cfg)

    @abstractmethod
    def _chat_with_functions(
        self,
        messages: List[Union[Message, Dict]],
        functions: List[Dict],
        stream: bool,
        delta_stream: bool,
        generate_cfg: dict,
        lang: Literal['en', 'zh'],
    ) -> Union[List[Message], Iterator[List[Message]]]:
        raise NotImplementedError

    @abstractmethod
    def _chat_stream(
        self,
        messages: List[Message],
        delta_stream: bool,
        generate_cfg: dict,
    ) -> Iterator[List[Message]]:
        raise NotImplementedError

    @abstractmethod
    def _chat_no_stream(
        self,
        messages: List[Message],
        generate_cfg: dict,
    ) -> List[Message]:
        raise NotImplementedError

    def _preprocess_messages(
        self,
        messages: List[Message],
        lang: Literal['en', 'zh'],
        generate_cfg: dict,
        functions: Optional[List[Dict]] = None,
    ) -> List[Message]:
        messages = [format_as_multimodal_message(msg, add_upload_info=True, lang=lang) for msg in messages]
        return messages

    def _postprocess_messages(
        self,
        messages: List[Message],
        fncall_mode: bool,
        generate_cfg: dict,
    ) -> List[Message]:
        messages = [format_as_multimodal_message(msg, add_upload_info=False) for msg in messages]
        if not generate_cfg.get('skip_stopword_postproc', False):
            stop = generate_cfg.get('stop', [])
            messages = _postprocess_stop_words(messages, stop=stop)
        return messages

    def _postprocess_messages_iterator(
        self,
        messages: Iterator[List[Message]],
        fncall_mode: bool,
        generate_cfg: dict,
    ) -> Iterator[List[Message]]:
        pre_msg = []
        for pre_msg in messages:
            post_msg = self._postprocess_messages(pre_msg, fncall_mode=fncall_mode, generate_cfg=generate_cfg)
            if not self.support_multimodal_output:
                post_msg = _format_as_text_messages(messages=post_msg)
            if post_msg:
                yield post_msg
        logger.debug(f'LLM Output:\n{pformat([_.model_dump() for _ in pre_msg], indent=2)}')

    def _convert_messages_to_target_type(self, messages: List[Message],
                                         target_type: str) -> Union[List[Message], List[Dict]]:
        if target_type == 'message':
            return [Message(**x) if isinstance(x, dict) else x for x in messages]
        elif target_type == 'dict':
            return [x.model_dump() if not isinstance(x, dict) else x for x in messages]
        else:
            raise NotImplementedError

    def _convert_messages_iterator_to_target_type(
            self, messages_iter: Iterator[List[Message]],
            target_type: str) -> Union[Iterator[List[Message]], Iterator[List[Dict]]]:
        for messages in messages_iter:
            yield self._convert_messages_to_target_type(messages, target_type)
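

# Minimal subclass sketch (illustrative, not part of the original module; the backend helpers
# `call_my_backend` and `call_my_backend_streaming` are hypothetical). A concrete backend only
# implements the three abstract hooks; BaseChatModel.chat() supplies truncation, retries, and
# message-format conversion on top of them:
#
#     @register_llm('my_backend')
#     class MyBackendChatModel(BaseChatModel):
#
#         def _chat_stream(self, messages, delta_stream, generate_cfg):
#             full_text = ''
#             for chunk in call_my_backend_streaming(messages, **generate_cfg):
#                 full_text += chunk
#                 yield [Message(role='assistant', content=chunk if delta_stream else full_text)]
#
#         def _chat_no_stream(self, messages, generate_cfg):
#             return [Message(role='assistant', content=call_my_backend(messages, **generate_cfg))]
#
#         def _chat_with_functions(self, messages, functions, stream, delta_stream, generate_cfg, lang):
#             # e.g., inject a function-calling prompt (omitted here) and delegate to _chat()
#             return self._chat(messages, stream=stream, delta_stream=delta_stream, generate_cfg=generate_cfg)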


def _format_as_text_messages(messages: List[Message]) -> List[Message]:
    for msg in messages:
        if isinstance(msg.content, list):
            for item in msg.content:
                assert item.type == 'text'
        else:
            assert isinstance(msg.content, str)

    messages = [format_as_text_message(msg, add_upload_info=False) for msg in messages]
    return messages


def _postprocess_stop_words(messages: List[Message], stop: List[str]) -> List[Message]:
    messages = copy.deepcopy(messages)

    # Make sure generation stops before the stop words.
    trunc_messages = []
    for msg in messages:
        truncated = False
        trunc_content = []
        for i, item in enumerate(msg.content):
            item_type, item_text = item.get_type_and_value()
            if item_type == 'text':
                truncated, item.text = _truncate_at_stop_word(text=item_text, stop=stop)
            trunc_content.append(item)
            if truncated:
                break
        msg.content = trunc_content
        trunc_messages.append(msg)
        if truncated:
            break
    messages = trunc_messages

    # The output may end with a partial stop word, e.g., 'Observation' when the full stop word is 'Observation:'.
    # The following post-processing step removes such partial stop words.
    partial_stop = []
    for s in stop:
        s = tokenizer.tokenize(s)[:-1]
        if s:
            s = tokenizer.convert_tokens_to_string(s)
            partial_stop.append(s)
    partial_stop = sorted(set(partial_stop))
    last_msg = messages[-1].content
    for i in range(len(last_msg) - 1, -1, -1):
        item_type, item_text = last_msg[i].get_type_and_value()
        if item_type == 'text':
            for s in partial_stop:
                if item_text.endswith(s):
                    last_msg[i].text = item_text[:-len(s)]
            break
    return messages
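
# Worked example of the partial-stop-word cleanup above (illustrative): with stop=['Observation:'],
# if the tokenizer splits 'Observation:' into ['Observation', ':'], then partial_stop == ['Observation'],
# so a response ending in the dangling prefix '...\nObservation' has that trailing fragment removed.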


def _truncate_at_stop_word(text: str, stop: List[str]):
    truncated = False
    for s in stop:
        k = text.find(s)
        if k >= 0:
            truncated = True
            text = text[:k]
    return truncated, text
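
# For example: _truncate_at_stop_word('Thought: search.\nObservation: ...', stop=['Observation:'])
# returns (True, 'Thought: search.\n'), i.e., everything from the first stop word onwards is dropped.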


def _truncate_input_messages_roughly(messages: List[Message], max_tokens: int) -> List[Message]:
    sys_msg = messages[0]
    assert sys_msg.role == SYSTEM  # The default system message is prepended if none exists
    if len([m for m in messages if m.role == SYSTEM]) >= 2:
        raise ModelServiceError(
            code='400',
            message='The input messages must contain no more than one system message. '
            'And the system message, if it exists, must be the first message.',
        )

    turns = []
    for m in messages[1:]:
        if m.role == USER:
            turns.append([m])
        else:
            if turns:
                turns[-1].append(m)
            else:
                raise ModelServiceError(
                    code='400',
                    message='The input messages (excluding the system message) must start with a user message.',
                )

    def _count_tokens(msg: Message) -> int:
        return tokenizer.count_tokens(extract_text_from_message(msg, add_upload_info=True))

    token_cnt = _count_tokens(sys_msg)
    truncated = []
    for i, turn in enumerate(reversed(turns)):
        cur_turn_msgs = []
        cur_token_cnt = 0
        for m in reversed(turn):
            cur_turn_msgs.append(m)
            cur_token_cnt += _count_tokens(m)
        # Check "i == 0" so that at least one user message is included
        if (i == 0) or (token_cnt + cur_token_cnt <= max_tokens):
            truncated.extend(cur_turn_msgs)
            token_cnt += cur_token_cnt
        else:
            break
    # Always include the system message
    truncated.append(sys_msg)
    truncated.reverse()

    if len(truncated) < 2:  # one system message + one or more user messages
        raise ModelServiceError(
            code='400',
            message='At least one user message should be provided.',
        )
    if token_cnt > max_tokens:
        raise ModelServiceError(
            code='400',
            message=f'The input messages exceed the maximum context length ({max_tokens} tokens) after '
            f'keeping only the system message and the latest user message (around {token_cnt} tokens). '
            'To configure the context limit, please specify "max_input_tokens" in the model generate_cfg. '
            f'Example: generate_cfg = {{..., "max_input_tokens": {(token_cnt // 100 + 1) * 100}}}',
        )
    return truncated
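
# Truncation strategy sketch (illustrative): messages are grouped into user-led turns and whole
# turns are dropped oldest-first until the rough token count fits max_tokens; the system message
# and the latest turn are always kept. E.g., with turns [[u1, a1], [u2, a2], [u3]] and a tight
# budget, the result may be [sys, u3] or [sys, u2, a2, u3], but a turn is never split in the middle.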


def retry_model_service(
    fn,
    max_retries: int = 10,
) -> Any:
    """Retry a function"""
    num_retries, delay = 0, 1.0
    while True:
        try:
            return fn()
        except ModelServiceError as e:
            num_retries, delay = _raise_or_delay(e, num_retries, delay, max_retries)


def retry_model_service_iterator(
    it_fn,
    max_retries: int = 10,
) -> Iterator:
    """Retry an iterator"""
    num_retries, delay = 0, 1.0
    while True:
        try:
            for rsp in it_fn():
                yield rsp
            break
        except ModelServiceError as e:
            num_retries, delay = _raise_or_delay(e, num_retries, delay, max_retries)


def _raise_or_delay(
    e: ModelServiceError,
    num_retries: int,
    delay: float,
    max_retries: int = 10,
    max_delay: float = 300.0,
    exponential_base: float = 2.0,
) -> Tuple[int, float]:
    """Retry with exponential backoff"""
    if max_retries <= 0:  # no retry
        raise e

    # Bad request, e.g., incorrect config or input
    if e.code == '400':
        raise e

    # If harmful input or output is detected, let it fail
    if e.code == 'DataInspectionFailed':
        raise e
    if 'inappropriate content' in str(e):
        raise e

    # Retrying is meaningless if the input is too long
    if 'maximum context length' in str(e):
        raise e

    logger.warning('ModelServiceError - ' + str(e).strip('\n'))

    if num_retries >= max_retries:
        raise ModelServiceError(exception=Exception(f'Maximum number of retries ({max_retries}) exceeded.'))

    num_retries += 1
    jitter = 1.0 + random.random()
    delay = min(delay * exponential_base, max_delay) * jitter
    time.sleep(delay)
    return num_retries, delay
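
# Backoff behaviour of _raise_or_delay (summary): non-retryable failures (code '400', content
# inspection rejections, context-length overflow) are re-raised immediately; otherwise the delay
# roughly doubles per retry with a 1x-2x jitter factor, is capped at max_delay before jitter, and
# a final ModelServiceError is raised once max_retries is exceeded.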