import copy
from typing import Union

# NOTE(review): importing these backend modules presumably has the side effect
# of registering each chat-model class in LLM_REGISTRY (used by
# get_chat_model below) — confirm against the decorators in .base.
from .azure import TextChatAtAzure
from .base import LLM_REGISTRY, BaseChatModel, ModelServiceError
from .oai import TextChatAtOAI
from .openvino import OpenVINO
from .qwen_dashscope import QwenChatAtDS
from .qwenvl_dashscope import QwenVLChatAtDS
from .qwenvl_oai import QwenVLChatAtOAI
  10. def get_chat_model(cfg: Union[dict, str] = 'qwen-plus') -> BaseChatModel:
  11. """The interface of instantiating LLM objects.
  12. Args:
  13. cfg: The LLM configuration, one example is:
  14. cfg = {
  15. # Use the model service provided by DashScope:
  16. 'model': 'qwen-max',
  17. 'model_server': 'dashscope',
  18. # Use your own model service compatible with OpenAI API:
  19. # 'model': 'Qwen',
  20. # 'model_server': 'http://127.0.0.1:7905/v1',
  21. # (Optional) LLM hyper-parameters:
  22. 'generate_cfg': {
  23. 'top_p': 0.8,
  24. 'max_input_tokens': 6500,
  25. 'max_retries': 10,
  26. }
  27. }
  28. Returns:
  29. LLM object.
  30. """
  31. if isinstance(cfg, str):
  32. cfg = {'model': cfg}
  33. if 'model_type' in cfg:
  34. model_type = cfg['model_type']
  35. if model_type in LLM_REGISTRY:
  36. if model_type in ('oai', 'qwenvl_oai'):
  37. if cfg.get('model_server', '').strip() == 'dashscope':
  38. cfg = copy.deepcopy(cfg)
  39. cfg['model_server'] = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
  40. return LLM_REGISTRY[model_type](cfg)
  41. else:
  42. raise ValueError(f'Please set model_type from {str(LLM_REGISTRY.keys())}')
  43. # Deduce model_type from model and model_server if model_type is not provided:
  44. if 'azure_endpoint' in cfg:
  45. model_type = 'azure'
  46. return LLM_REGISTRY[model_type](cfg)
  47. if 'model_server' in cfg:
  48. if cfg['model_server'].strip().startswith('http'):
  49. model_type = 'oai'
  50. return LLM_REGISTRY[model_type](cfg)
  51. model = cfg.get('model', '')
  52. if 'qwen-vl' in model:
  53. model_type = 'qwenvl_dashscope'
  54. return LLM_REGISTRY[model_type](cfg)
  55. if 'qwen' in model:
  56. model_type = 'qwen_dashscope'
  57. return LLM_REGISTRY[model_type](cfg)
  58. raise ValueError(f'Invalid model cfg: {cfg}')
# Public API of this package: the backend classes, the factory function,
# and the service-error type, matching the names imported above.
__all__ = [
    'BaseChatModel',
    'QwenChatAtDS',
    'TextChatAtOAI',
    'TextChatAtAzure',
    'QwenVLChatAtDS',
    'QwenVLChatAtOAI',
    'OpenVINO',
    'get_chat_model',
    'ModelServiceError',
]