base.py

import copy
import random
import time
from abc import ABC, abstractmethod
from pprint import pformat
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union

from qwen_agent.llm.schema import DEFAULT_SYSTEM_MESSAGE, SYSTEM, USER, Message
from qwen_agent.log import logger
from qwen_agent.settings import DEFAULT_MAX_INPUT_TOKENS
from qwen_agent.utils.tokenization_qwen import tokenizer
from qwen_agent.utils.utils import (extract_text_from_message, format_as_multimodal_message, format_as_text_message,
                                    has_chinese_messages, merge_generate_cfgs)

LLM_REGISTRY = {}


def register_llm(model_type):

    def decorator(cls):
        LLM_REGISTRY[model_type] = cls
        return cls

    return decorator
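
# Illustrative usage (a minimal sketch, not part of this module): concrete chat-model
# classes register themselves under a backend name so they can later be looked up in
# LLM_REGISTRY. The name 'my_backend' and the class MyChatModel are hypothetical.
#
#   @register_llm('my_backend')
#   class MyChatModel(BaseChatModel):
#       ...
#
#   assert LLM_REGISTRY['my_backend'] is MyChatModel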


class ModelServiceError(Exception):

    def __init__(self,
                 exception: Optional[Exception] = None,
                 code: Optional[str] = None,
                 message: Optional[str] = None):
        if exception is not None:
            super().__init__(exception)
        else:
            super().__init__(f'\nError code: {code}. Error message: {message}')
        self.exception = exception
        self.code = code
        self.message = message
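
# Illustrative: the error is raised either by wrapping a caught exception or with an
# explicit code/message pair, e.g.
#   raise ModelServiceError(exception=err)                       # wrap a backend exception
#   raise ModelServiceError(code='400', message='Bad request.')  # explicit code and message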


class BaseChatModel(ABC):
    """The base class of LLM"""

    @property
    def support_multimodal_input(self) -> bool:
        # Does the model support multimodal input natively? It affects how we preprocess the input.
        return False

    @property
    def support_multimodal_output(self) -> bool:
        # Does the model generate multimodal outputs beyond text? It affects how we post-process the output.
        return False

    def __init__(self, cfg: Optional[Dict] = None):
        cfg = cfg or {}
        self.model = cfg.get('model', '').strip()
        generate_cfg = copy.deepcopy(cfg.get('generate_cfg', {}))
        self.max_retries = generate_cfg.pop('max_retries', 0)
        self.generate_cfg = generate_cfg

    def quick_chat(self, prompt: str) -> str:
        *_, responses = self.chat(messages=[Message(role=USER, content=prompt)])
        assert len(responses) == 1
        assert not responses[0].function_call
        assert isinstance(responses[0].content, str)
        return responses[0].content

    def chat(
        self,
        messages: List[Union[Message, Dict]],
        functions: Optional[List[Dict]] = None,
        stream: bool = True,
        delta_stream: bool = False,
        extra_generate_cfg: Optional[Dict] = None,
    ) -> Union[List[Message], List[Dict], Iterator[List[Message]], Iterator[List[Dict]]]:
        """LLM chat interface.

        Args:
            messages: Inputted messages.
            functions: Inputted functions for function calling. The OpenAI format is supported.
            stream: Whether to use streaming generation.
            delta_stream: Whether to stream the response incrementally.
                (1) When False (recommended): Stream the full response on every iteration.
                (2) When True: Stream the chunked response, i.e., delta responses.
            extra_generate_cfg: Extra LLM generation hyper-parameters.

        Returns:
            The list of messages generated by the LLM.
        """
        if stream and delta_stream:
            logger.warning(
                'Support for `delta_stream=True` is deprecated. '
                'Please use `stream=True and delta_stream=False` or `stream=False` instead. '
                'Using `delta_stream=True` makes it difficult to implement advanced postprocessing and retry mechanisms.'
            )

        generate_cfg = merge_generate_cfgs(base_generate_cfg=self.generate_cfg, new_generate_cfg=extra_generate_cfg)
        if 'seed' not in generate_cfg:
            generate_cfg['seed'] = random.randint(a=0, b=2**30)
        if 'lang' in generate_cfg:
            lang: Literal['en', 'zh'] = generate_cfg.pop('lang')
        else:
            lang: Literal['en', 'zh'] = 'zh' if has_chinese_messages(messages) else 'en'

        messages = copy.deepcopy(messages)
        _return_message_type = 'dict'
        new_messages = []
        for msg in messages:
            if isinstance(msg, dict):
                new_messages.append(Message(**msg))
            else:
                new_messages.append(msg)
                _return_message_type = 'message'
        messages = new_messages

        if messages[0].role != SYSTEM:
            messages = [Message(role=SYSTEM, content=DEFAULT_SYSTEM_MESSAGE)] + messages

        # Not precise. It's hard to estimate the tokens related to function calling and multimodal items.
        max_input_tokens = generate_cfg.pop('max_input_tokens', DEFAULT_MAX_INPUT_TOKENS)
        if max_input_tokens > 0:
            messages = _truncate_input_messages_roughly(
                messages=messages,
                max_tokens=max_input_tokens,
            )

        if functions:
            fncall_mode = True
        else:
            fncall_mode = False
        if 'function_choice' in generate_cfg:
            fn_choice = generate_cfg['function_choice']
            valid_fn_choices = [f.get('name', f.get('name_for_model', None)) for f in (functions or [])]
            valid_fn_choices = ['auto', 'none'] + [f for f in valid_fn_choices if f]
            if fn_choice not in valid_fn_choices:
                raise ValueError(f'The value of function_choice must be one of the following: {valid_fn_choices}. '
                                 f'But function_choice="{fn_choice}" is received.')
            if fn_choice == 'none':
                fncall_mode = False

        # Note: the preprocessor's behavior could change if it receives function_choice="none"
        messages = self._preprocess_messages(messages, lang=lang, generate_cfg=generate_cfg, functions=functions)
        if not self.support_multimodal_input:
            messages = [format_as_text_message(msg, add_upload_info=False) for msg in messages]

        if not fncall_mode:
            for k in ['parallel_function_calls', 'function_choice']:
                if k in generate_cfg:
                    del generate_cfg[k]

        def _call_model_service():
            if fncall_mode:
                return self._chat_with_functions(
                    messages=messages,
                    functions=functions,
                    stream=stream,
                    delta_stream=delta_stream,
                    generate_cfg=generate_cfg,
                    lang=lang,
                )
            else:
                return self._chat(
                    messages,
                    stream=stream,
                    delta_stream=delta_stream,
                    generate_cfg=generate_cfg,
                )

        if stream and delta_stream:
            # No retry for delta streaming
            output = _call_model_service()
        elif stream and (not delta_stream):
            output = retry_model_service_iterator(_call_model_service, max_retries=self.max_retries)
        else:
            output = retry_model_service(_call_model_service, max_retries=self.max_retries)

        if isinstance(output, list):
            assert not stream
            logger.debug(f'LLM Output:\n{pformat([_.model_dump() for _ in output], indent=2)}')
            output = self._postprocess_messages(output, fncall_mode=fncall_mode, generate_cfg=generate_cfg)
            if not self.support_multimodal_output:
                output = _format_as_text_messages(messages=output)
            return self._convert_messages_to_target_type(output, _return_message_type)
        else:
            assert stream
            if delta_stream:
                # Hack: To avoid potential errors during the postprocessing of stop words when delta_stream=True.
                # We should never have implemented support for `delta_stream=True` in the first place!
                generate_cfg = copy.deepcopy(generate_cfg)  # copy to avoid conflicts with `_call_model_service`
                assert 'skip_stopword_postproc' not in generate_cfg
                generate_cfg['skip_stopword_postproc'] = True
            output = self._postprocess_messages_iterator(output, fncall_mode=fncall_mode, generate_cfg=generate_cfg)
            return self._convert_messages_iterator_to_target_type(output, _return_message_type)

    def _chat(
        self,
        messages: List[Union[Message, Dict]],
        stream: bool,
        delta_stream: bool,
        generate_cfg: dict,
    ) -> Union[List[Message], Iterator[List[Message]]]:
        if stream:
            return self._chat_stream(messages, delta_stream=delta_stream, generate_cfg=generate_cfg)
        else:
            return self._chat_no_stream(messages, generate_cfg=generate_cfg)

    @abstractmethod
    def _chat_with_functions(
        self,
        messages: List[Union[Message, Dict]],
        functions: List[Dict],
        stream: bool,
        delta_stream: bool,
        generate_cfg: dict,
        lang: Literal['en', 'zh'],
    ) -> Union[List[Message], Iterator[List[Message]]]:
        raise NotImplementedError

    @abstractmethod
    def _chat_stream(
        self,
        messages: List[Message],
        delta_stream: bool,
        generate_cfg: dict,
    ) -> Iterator[List[Message]]:
        raise NotImplementedError

    @abstractmethod
    def _chat_no_stream(
        self,
        messages: List[Message],
        generate_cfg: dict,
    ) -> List[Message]:
        raise NotImplementedError

    def _preprocess_messages(
        self,
        messages: List[Message],
        lang: Literal['en', 'zh'],
        generate_cfg: dict,
        functions: Optional[List[Dict]] = None,
    ) -> List[Message]:
        messages = [format_as_multimodal_message(msg, add_upload_info=True, lang=lang) for msg in messages]
        return messages

    def _postprocess_messages(
        self,
        messages: List[Message],
        fncall_mode: bool,
        generate_cfg: dict,
    ) -> List[Message]:
        messages = [format_as_multimodal_message(msg, add_upload_info=False) for msg in messages]
        if not generate_cfg.get('skip_stopword_postproc', False):
            stop = generate_cfg.get('stop', [])
            messages = _postprocess_stop_words(messages, stop=stop)
        return messages

    def _postprocess_messages_iterator(
        self,
        messages: Iterator[List[Message]],
        fncall_mode: bool,
        generate_cfg: dict,
    ) -> Iterator[List[Message]]:
        pre_msg = []
        for pre_msg in messages:
            post_msg = self._postprocess_messages(pre_msg, fncall_mode=fncall_mode, generate_cfg=generate_cfg)
            if not self.support_multimodal_output:
                post_msg = _format_as_text_messages(messages=post_msg)
            if post_msg:
                yield post_msg
        logger.debug(f'LLM Output:\n{pformat([_.model_dump() for _ in pre_msg], indent=2)}')

    def _convert_messages_to_target_type(self, messages: List[Message],
                                         target_type: str) -> Union[List[Message], List[Dict]]:
        if target_type == 'message':
            return [Message(**x) if isinstance(x, dict) else x for x in messages]
        elif target_type == 'dict':
            return [x.model_dump() if not isinstance(x, dict) else x for x in messages]
        else:
            raise NotImplementedError

    def _convert_messages_iterator_to_target_type(
            self, messages_iter: Iterator[List[Message]],
            target_type: str) -> Union[Iterator[List[Message]], Iterator[List[Dict]]]:
        for messages in messages_iter:
            yield self._convert_messages_to_target_type(messages, target_type)
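
# Illustrative sketch (hypothetical names, based on the abstract hooks above): a minimal
# subclass implements the three abstract methods and then serves requests via `chat()`.
#
#   @register_llm('my_backend')
#   class MyChatModel(BaseChatModel):
#
#       def _chat_no_stream(self, messages, generate_cfg):
#           return [Message(role='assistant', content='...')]
#
#       def _chat_stream(self, messages, delta_stream, generate_cfg):
#           yield [Message(role='assistant', content='...')]
#
#       def _chat_with_functions(self, messages, functions, stream, delta_stream, generate_cfg, lang):
#           return self._chat(messages, stream=stream, delta_stream=delta_stream, generate_cfg=generate_cfg)
#
#   llm = MyChatModel({'model': 'my-model', 'generate_cfg': {'max_retries': 3, 'max_input_tokens': 6000}})
#   # stream=False returns a list; stream=True returns an iterator of growing message lists,
#   # so `*_, final = ...` keeps only the last (complete) response.
#   final = llm.chat(messages=[{'role': 'user', 'content': 'hi'}], stream=False)
#   *_, final = llm.chat(messages=[{'role': 'user', 'content': 'hi'}], stream=True)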


def _format_as_text_messages(messages: List[Message]) -> List[Message]:
    for msg in messages:
        if isinstance(msg.content, list):
            for item in msg.content:
                assert item.type == 'text'
        else:
            assert isinstance(msg.content, str)

    messages = [format_as_text_message(msg, add_upload_info=False) for msg in messages]
    return messages


def _postprocess_stop_words(messages: List[Message], stop: List[str]) -> List[Message]:
    messages = copy.deepcopy(messages)

    # Make sure the output stops before the stop words.
    trunc_messages = []
    for msg in messages:
        truncated = False
        trunc_content = []
        for i, item in enumerate(msg.content):
            item_type, item_text = item.get_type_and_value()
            if item_type == 'text':
                truncated, item.text = _truncate_at_stop_word(text=item_text, stop=stop)
            trunc_content.append(item)
            if truncated:
                break
        msg.content = trunc_content
        trunc_messages.append(msg)
        if truncated:
            break
    messages = trunc_messages

    # The output may end with a partial stop word, e.g., 'Observation' when the full stop word is 'Observation:'.
    # The following post-processing step removes such partial stop words.
    partial_stop = []
    for s in stop:
        s = tokenizer.tokenize(s)[:-1]
        if s:
            s = tokenizer.convert_tokens_to_string(s)
            partial_stop.append(s)
    partial_stop = sorted(set(partial_stop))
    last_msg = messages[-1].content
    for i in range(len(last_msg) - 1, -1, -1):
        item_type, item_text = last_msg[i].get_type_and_value()
        if item_type == 'text':
            for s in partial_stop:
                if item_text.endswith(s):
                    last_msg[i].text = item_text[:-len(s)]
            break
    return messages


def _truncate_at_stop_word(text: str, stop: List[str]):
    truncated = False
    for s in stop:
        k = text.find(s)
        if k >= 0:
            truncated = True
            text = text[:k]
    return truncated, text
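
# Illustrative behaviour (assuming the stop word 'Observation:' of a ReAct-style prompt):
#   _truncate_at_stop_word('Thought: search it\nObservation: ...', stop=['Observation:'])
#   # -> (True, 'Thought: search it\n')
# _postprocess_stop_words additionally strips a trailing *partial* stop word such as
# 'Observation' (missing the final ':') from the last text item.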


def _truncate_input_messages_roughly(messages: List[Message], max_tokens: int) -> List[Message]:
    sys_msg = messages[0]
    assert sys_msg.role == SYSTEM  # The default system message is prepended if none exists
    if len([m for m in messages if m.role == SYSTEM]) >= 2:
        raise ModelServiceError(
            code='400',
            message='The input messages must contain no more than one system message. '
            'And the system message, if it exists, must be the first message.',
        )

    turns = []
    for m in messages[1:]:
        if m.role == USER:
            turns.append([m])
        else:
            if turns:
                turns[-1].append(m)
            else:
                raise ModelServiceError(
                    code='400',
                    message='The input messages (excluding the system message) must start with a user message.',
                )

    def _count_tokens(msg: Message) -> int:
        return tokenizer.count_tokens(extract_text_from_message(msg, add_upload_info=True))

    token_cnt = _count_tokens(sys_msg)
    truncated = []
    for i, turn in enumerate(reversed(turns)):
        cur_turn_msgs = []
        cur_token_cnt = 0
        for m in reversed(turn):
            cur_turn_msgs.append(m)
            cur_token_cnt += _count_tokens(m)
        # Check "i == 0" so that at least one user turn is always included
        if (i == 0) or (token_cnt + cur_token_cnt <= max_tokens):
            truncated.extend(cur_turn_msgs)
            token_cnt += cur_token_cnt
        else:
            break
    # Always include the system message
    truncated.append(sys_msg)
    truncated.reverse()

    if len(truncated) < 2:  # one system message + one or more user messages
        raise ModelServiceError(
            code='400',
            message='At least one user message should be provided.',
        )
    if token_cnt > max_tokens:
        raise ModelServiceError(
            code='400',
            message=f'The input messages exceed the maximum context length ({max_tokens} tokens) even after '
            f'keeping only the system message and the latest user turn (around {token_cnt} tokens). '
            'To configure the context limit, please specify "max_input_tokens" in the model generate_cfg. '
            f'Example: generate_cfg = {{..., "max_input_tokens": {(token_cnt // 100 + 1) * 100}}}',
        )
    return truncated
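
# Illustrative arithmetic (hypothetical token counts): with max_tokens=100, a 20-token
# system message, and turns costing 50, 60 and 30 tokens (oldest to newest), the newest
# turn is always kept (20 + 30 = 50 tokens so far); the 60-token turn no longer fits
# (50 + 60 > 100), so it and everything older are dropped. If even the newest turn
# exceeds the budget, a code='400' ModelServiceError is raised instead.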


def retry_model_service(
    fn,
    max_retries: int = 10,
) -> Any:
    """Retry a function."""
    num_retries, delay = 0, 1.0
    while True:
        try:
            return fn()
        except ModelServiceError as e:
            num_retries, delay = _raise_or_delay(e, num_retries, delay, max_retries)


def retry_model_service_iterator(
    it_fn,
    max_retries: int = 10,
) -> Iterator:
    """Retry an iterator."""
    num_retries, delay = 0, 1.0
    while True:
        try:
            for rsp in it_fn():
                yield rsp
            break
        except ModelServiceError as e:
            num_retries, delay = _raise_or_delay(e, num_retries, delay, max_retries)


def _raise_or_delay(
    e: ModelServiceError,
    num_retries: int,
    delay: float,
    max_retries: int = 10,
    max_delay: float = 300.0,
    exponential_base: float = 2.0,
) -> Tuple[int, float]:
    """Retry with exponential backoff."""
    if max_retries <= 0:  # no retry
        raise e

    # Bad request, e.g., incorrect config or input
    if e.code == '400':
        raise e

    # If harmful input or output is detected, let it fail
    if e.code == 'DataInspectionFailed':
        raise e
    if 'inappropriate content' in str(e):
        raise e

    # Retrying is meaningless if the input is too long
    if 'maximum context length' in str(e):
        raise e

    logger.warning('ModelServiceError - ' + str(e).strip('\n'))

    if num_retries >= max_retries:
        raise ModelServiceError(exception=Exception(f'Maximum number of retries ({max_retries}) exceeded.'))

    num_retries += 1
    jitter = 1.0 + random.random()
    delay = min(delay * exponential_base, max_delay) * jitter
    time.sleep(delay)
    return num_retries, delay
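
# Illustrative usage (hypothetical backend names): implementations of the _chat_* hooks wrap
# backend failures in ModelServiceError, and chat() drives the retry helpers above. A direct
# use would look like:
#
#   def _call_once():
#       try:
#           return my_backend_request()          # hypothetical network call
#       except MyBackendError as e:              # hypothetical exception type
#           raise ModelServiceError(exception=e)
#
#   result = retry_model_service(_call_once, max_retries=3)
#
# Errors with code '400' or 'DataInspectionFailed', or whose text mentions 'inappropriate
# content' or 'maximum context length', are re-raised immediately; everything else is
# retried with exponential backoff plus jitter, capped at max_delay seconds per wait.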