azure.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. import os
  2. from typing import Dict, Optional
  3. import openai
  4. from qwen_agent.llm.base import register_llm
  5. from qwen_agent.llm.oai import TextChatAtOAI
  6. @register_llm('azure')
  7. class TextChatAtAzure(TextChatAtOAI):
  8. def __init__(self, cfg: Optional[Dict] = None):
  9. super().__init__(cfg)
  10. cfg = cfg or {}
  11. api_base = cfg.get('api_base')
  12. api_base = api_base or cfg.get('base_url')
  13. api_base = api_base or cfg.get('model_server')
  14. api_base = api_base or cfg.get('azure_endpoint')
  15. api_base = (api_base or '').strip()
  16. api_key = cfg.get('api_key')
  17. api_key = api_key or os.getenv('OPENAI_API_KEY')
  18. api_key = (api_key or 'EMPTY').strip()
  19. api_version = cfg.get('api_version', '2024-06-01')
  20. api_kwargs = {}
  21. if api_base:
  22. api_kwargs['azure_endpoint'] = api_base
  23. if api_key:
  24. api_kwargs['api_key'] = api_key
  25. if api_version:
  26. api_kwargs['api_version'] = api_version
  27. def _chat_complete_create(*args, **kwargs):
  28. client = openai.AzureOpenAI(**api_kwargs)
  29. return client.chat.completions.create(*args, **kwargs)
  30. self._chat_complete_create = _chat_complete_create