qwen_dashscope.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. import os
  2. from http import HTTPStatus
  3. from pprint import pformat
  4. from typing import Dict, Iterator, List, Optional
  5. import dashscope
  6. from qwen_agent.llm.base import ModelServiceError, register_llm
  7. from qwen_agent.llm.function_calling import BaseFnCallModel
  8. from qwen_agent.llm.schema import ASSISTANT, Message
  9. from qwen_agent.log import logger
  10. from qwen_agent.utils.utils import build_text_completion_prompt
  11. @register_llm('qwen_dashscope')
  12. class QwenChatAtDS(BaseFnCallModel):
  13. def __init__(self, cfg: Optional[Dict] = None):
  14. super().__init__(cfg)
  15. self.model = self.model or 'qwen-max'
  16. initialize_dashscope(cfg)
  17. def _chat_stream(
  18. self,
  19. messages: List[Message],
  20. delta_stream: bool,
  21. generate_cfg: dict,
  22. ) -> Iterator[List[Message]]:
  23. messages = [msg.model_dump() for msg in messages]
  24. logger.debug(f'LLM Input:\n{pformat(messages, indent=2)}')
  25. response = dashscope.Generation.call(
  26. self.model,
  27. messages=messages, # noqa
  28. result_format='message',
  29. stream=True,
  30. **generate_cfg)
  31. if delta_stream:
  32. return self._delta_stream_output(response)
  33. else:
  34. return self._full_stream_output(response)
  35. def _chat_no_stream(
  36. self,
  37. messages: List[Message],
  38. generate_cfg: dict,
  39. ) -> List[Message]:
  40. messages = [msg.model_dump() for msg in messages]
  41. logger.debug(f'LLM Input:\n{pformat(messages, indent=2)}')
  42. response = dashscope.Generation.call(
  43. self.model,
  44. messages=messages, # noqa
  45. result_format='message',
  46. stream=False,
  47. **generate_cfg)
  48. if response.status_code == HTTPStatus.OK:
  49. return [Message(ASSISTANT, response.output.choices[0].message.content)]
  50. else:
  51. raise ModelServiceError(code=response.code, message=response.message)
  52. def _continue_assistant_response(
  53. self,
  54. messages: List[Message],
  55. generate_cfg: dict,
  56. stream: bool,
  57. ) -> Iterator[List[Message]]:
  58. prompt = build_text_completion_prompt(messages)
  59. logger.debug(f'LLM Input:\n{pformat(prompt, indent=2)}')
  60. response = dashscope.Generation.call(
  61. self.model,
  62. prompt=prompt, # noqa
  63. result_format='message',
  64. stream=True,
  65. use_raw_prompt=True,
  66. **generate_cfg)
  67. it = self._full_stream_output(response)
  68. if stream:
  69. return it # streaming the response
  70. else:
  71. *_, final_response = it # return the final response without streaming
  72. return final_response
  73. @staticmethod
  74. def _delta_stream_output(response) -> Iterator[List[Message]]:
  75. last_len = 0
  76. delay_len = 5
  77. in_delay = False
  78. text = ''
  79. for chunk in response:
  80. if chunk.status_code == HTTPStatus.OK:
  81. text = chunk.output.choices[0].message.content
  82. if (len(text) - last_len) <= delay_len:
  83. in_delay = True
  84. continue
  85. else:
  86. in_delay = False
  87. real_text = text[:-delay_len]
  88. now_rsp = real_text[last_len:]
  89. yield [Message(ASSISTANT, now_rsp)]
  90. last_len = len(real_text)
  91. else:
  92. raise ModelServiceError(code=chunk.code, message=chunk.message)
  93. if text and (in_delay or (last_len != len(text))):
  94. yield [Message(ASSISTANT, text[last_len:])]
  95. @staticmethod
  96. def _full_stream_output(response) -> Iterator[List[Message]]:
  97. for chunk in response:
  98. if chunk.status_code == HTTPStatus.OK:
  99. yield [Message(ASSISTANT, chunk.output.choices[0].message.content)]
  100. else:
  101. raise ModelServiceError(code=chunk.code, message=chunk.message)
  102. def initialize_dashscope(cfg: Optional[Dict] = None) -> None:
  103. cfg = cfg or {}
  104. api_key = cfg.get('api_key', '')
  105. base_http_api_url = cfg.get('base_http_api_url', None)
  106. base_websocket_api_url = cfg.get('base_websocket_api_url', None)
  107. if not api_key:
  108. api_key = os.getenv('DASHSCOPE_API_KEY', 'EMPTY')
  109. if not base_http_api_url:
  110. base_http_api_url = os.getenv('DASHSCOPE_HTTP_URL', None)
  111. if not base_websocket_api_url:
  112. base_websocket_api_url = os.getenv('DASHSCOPE_WEBSOCKET_URL', None)
  113. api_key = api_key.strip()
  114. dashscope.api_key = api_key
  115. if base_http_api_url is not None:
  116. dashscope.base_http_api_url = base_http_api_url.strip()
  117. if base_websocket_api_url is not None:
  118. dashscope.base_websocket_api_url = base_websocket_api_url.strip()