# oai.py
  1. import copy
  2. import logging
  3. import os
  4. from pprint import pformat
  5. from typing import Dict, Iterator, List, Optional
  6. import openai
  7. if openai.__version__.startswith('0.'):
  8. from openai.error import OpenAIError # noqa
  9. else:
  10. from openai import OpenAIError
  11. from qwen_agent.llm.base import ModelServiceError, register_llm
  12. from qwen_agent.llm.function_calling import BaseFnCallModel
  13. from qwen_agent.llm.schema import ASSISTANT, Message
  14. from qwen_agent.log import logger
  15. @register_llm('oai')
  16. class TextChatAtOAI(BaseFnCallModel):
  17. def __init__(self, cfg: Optional[Dict] = None):
  18. super().__init__(cfg)
  19. self.model = self.model or 'gpt-4o-mini'
  20. cfg = cfg or {}
  21. api_base = cfg.get('api_base')
  22. api_base = api_base or cfg.get('base_url')
  23. api_base = api_base or cfg.get('model_server')
  24. api_base = (api_base or '').strip()
  25. api_key = cfg.get('api_key')
  26. api_key = api_key or os.getenv('OPENAI_API_KEY')
  27. api_key = (api_key or 'EMPTY').strip()
  28. if openai.__version__.startswith('0.'):
  29. if api_base:
  30. openai.api_base = api_base
  31. if api_key:
  32. openai.api_key = api_key
  33. self._chat_complete_create = openai.ChatCompletion.create
  34. else:
  35. api_kwargs = {}
  36. if api_base:
  37. api_kwargs['base_url'] = api_base
  38. if api_key:
  39. api_kwargs['api_key'] = api_key
  40. def _chat_complete_create(*args, **kwargs):
  41. # OpenAI API v1 does not allow the following args, must pass by extra_body
  42. extra_params = ['top_k', 'repetition_penalty']
  43. if any((k in kwargs) for k in extra_params):
  44. kwargs['extra_body'] = copy.deepcopy(kwargs.get('extra_body', {}))
  45. for k in extra_params:
  46. if k in kwargs:
  47. kwargs['extra_body'][k] = kwargs.pop(k)
  48. if 'request_timeout' in kwargs:
  49. kwargs['timeout'] = kwargs.pop('request_timeout')
  50. client = openai.OpenAI(**api_kwargs)
  51. return client.chat.completions.create(*args, **kwargs)
  52. self._chat_complete_create = _chat_complete_create
  53. def _chat_stream(
  54. self,
  55. messages: List[Message],
  56. delta_stream: bool,
  57. generate_cfg: dict,
  58. ) -> Iterator[List[Message]]:
  59. messages = self.convert_messages_to_dicts(messages)
  60. try:
  61. response = self._chat_complete_create(model=self.model, messages=messages, stream=True, **generate_cfg)
  62. if delta_stream:
  63. for chunk in response:
  64. if chunk.choices and hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
  65. yield [Message(ASSISTANT, chunk.choices[0].delta.content)]
  66. else:
  67. full_response = ''
  68. for chunk in response:
  69. if chunk.choices and hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
  70. full_response += chunk.choices[0].delta.content
  71. yield [Message(ASSISTANT, full_response)]
  72. except OpenAIError as ex:
  73. raise ModelServiceError(exception=ex)
  74. def _chat_no_stream(
  75. self,
  76. messages: List[Message],
  77. generate_cfg: dict,
  78. ) -> List[Message]:
  79. messages = self.convert_messages_to_dicts(messages)
  80. try:
  81. response = self._chat_complete_create(model=self.model, messages=messages, stream=False, **generate_cfg)
  82. return [Message(ASSISTANT, response.choices[0].message.content)]
  83. except OpenAIError as ex:
  84. raise ModelServiceError(exception=ex)
  85. @staticmethod
  86. def convert_messages_to_dicts(messages: List[Message]) -> List[dict]:
  87. messages = [msg.model_dump() for msg in messages]
  88. if logger.isEnabledFor(logging.DEBUG):
  89. logger.debug(f'LLM Input:\n{pformat(messages, indent=2)}')
  90. return messages