qwenvl_dashscope.py

import copy
import os
import re
from http import HTTPStatus
from pprint import pformat
from typing import Dict, Iterator, List, Optional

import dashscope

from qwen_agent.llm.base import ModelServiceError, register_llm
from qwen_agent.llm.function_calling import BaseFnCallModel
from qwen_agent.llm.qwen_dashscope import initialize_dashscope
from qwen_agent.llm.schema import ContentItem, Message
from qwen_agent.log import logger


@register_llm('qwenvl_dashscope')
class QwenVLChatAtDS(BaseFnCallModel):
    """Qwen-VL multimodal chat model served through DashScope's MultiModalConversation API."""

    @property
    def support_multimodal_input(self) -> bool:
        return True

    def __init__(self, cfg: Optional[Dict] = None):
        super().__init__(cfg)
        self.model = self.model or 'qwen-vl-max'
        initialize_dashscope(cfg)
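
    # Example configuration (illustrative; 'qwen-vl-max' is the default model, and the
    # DASHSCOPE_API_KEY environment variable is assumed to be set):
    #     llm = QwenVLChatAtDS({'model': 'qwen-vl-max', 'generate_cfg': {'top_p': 0.8}})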

    def _chat_stream(
        self,
        messages: List[Message],
        delta_stream: bool,
        generate_cfg: dict,
    ) -> Iterator[List[Message]]:
        # Only full-message streaming is supported; incremental (delta) streaming is not.
        if delta_stream:
            raise NotImplementedError

        messages = _format_local_files(messages)
        messages = [msg.model_dump() for msg in messages]
        logger.debug(f'LLM Input:\n{pformat(messages, indent=2)}')
        response = dashscope.MultiModalConversation.call(model=self.model,
                                                         messages=messages,
                                                         result_format='message',
                                                         stream=True,
                                                         **generate_cfg)
        for chunk in response:
            if chunk.status_code == HTTPStatus.OK:
                yield _extract_vl_response(chunk)
            else:
                raise ModelServiceError(code=chunk.code, message=chunk.message)

    def _chat_no_stream(
        self,
        messages: List[Message],
        generate_cfg: dict,
    ) -> List[Message]:
        messages = _format_local_files(messages)
        messages = [msg.model_dump() for msg in messages]
        logger.debug(f'LLM Input:\n{pformat(messages, indent=2)}')
        response = dashscope.MultiModalConversation.call(model=self.model,
                                                         messages=messages,
                                                         result_format='message',
                                                         stream=False,
                                                         **generate_cfg)
        if response.status_code == HTTPStatus.OK:
            return _extract_vl_response(response=response)
        else:
            raise ModelServiceError(code=response.code, message=response.message)


# DashScope Qwen-VL requires the following format for local files:
#   Linux & Mac: file:///home/images/test.png
#   Windows: file://D:/images/abc.png
def _format_local_files(messages: List[Message]) -> List[Message]:
    messages = copy.deepcopy(messages)
    for msg in messages:
        if isinstance(msg.content, list):
            for item in msg.content:
                if item.image:
                    fname = item.image
                    if not fname.startswith((
                            'http://',
                            'https://',
                            'file://',
                            'data:',  # base64 such as f"data:image/jpg;base64,{image_base64}"
                    )):
                        if fname.startswith('~'):
                            fname = os.path.expanduser(fname)
                        fname = os.path.abspath(fname)
                        if os.path.isfile(fname):
                            if re.match(r'^[A-Za-z]:\\', fname):
                                fname = fname.replace('\\', '/')
                            fname = 'file://' + fname
                            item.image = fname
    return messages
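
# Illustrative effect of _format_local_files (assumed example paths): when the file exists
# on disk, '~/images/test.png' becomes 'file:///home/<user>/images/test.png' on Linux/Mac
# and 'D:\images\abc.png' becomes 'file://D:/images/abc.png' on Windows; http(s) URLs,
# existing file:// URIs and data: URIs are passed through unchanged.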


def _extract_vl_response(response) -> List[Message]:
    output = response.output.choices[0].message
    text_content = []
    # DashScope returns each content item either as a plain string or as a dict;
    # keep only the 'text' and 'box' entries, both carried over as text.
    for item in output.content:
        if isinstance(item, str):
            text_content.append(ContentItem(text=item))
        else:
            for k, v in item.items():
                if k in ('text', 'box'):
                    text_content.append(ContentItem(text=v))
    return [Message(role=output.role, content=text_content)]
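

# Minimal usage sketch (illustrative, not part of the upstream module). It assumes
# qwen_agent's get_chat_model factory, a valid DASHSCOPE_API_KEY in the environment,
# and a placeholder local image path.
if __name__ == '__main__':
    from qwen_agent.llm import get_chat_model

    llm = get_chat_model({'model_type': 'qwenvl_dashscope', 'model': 'qwen-vl-max'})
    demo_messages = [
        Message(role='user',
                content=[
                    ContentItem(image='~/images/demo.png'),
                    ContentItem(text='Describe this image.'),
                ])
    ]
    # stream=False returns the full response as a list of Message objects.
    print(llm.chat(messages=demo_messages, stream=False))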