summary_agent.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. import asyncio
  2. import traceback
  3. from typing import List
  4. from httpx import RemoteProtocolError
  5. from qwen_agent.planning.planner import PlanResponseContextManager
  6. from qwen_agent.sub_agent.BaseSubAgent import BaseSubAgent
  7. from qwen_agent.messages.context_message import ChatResponseChoice, ChatResponseStreamChoice
  8. SYSTEM_PROMPT = """
  9. 你是一个对前面提到的分析过程和结果进行总结摘要的专家,给你一段用户的提问,以及之前代码执行的过程和执行结果,你可以对整个分析的过程进行总结摘要,来回答用户提出的问题。
  10. 那么我问你,请对前面提到的分析过程和结果进行总结摘要,回答用户的问题。
  11. 注意:
  12. 1. 如果用户查询的Question通过数据库查询,没有返回结果,请直接回答“数据库中没有查询到相关的数据”,不允许胡编乱造。如果在数据库中查询到相关的结果,请根据结果回答用户问题。
  13. 2. 请不要对show_case的查询到的记录做过与详细的描述;
  14. 3. 如果用户查询的Question没有涉及到数据库查询,比如GIS图层操作,请直接回答没有相关数据。
  15. """
  16. agents_prompt = dict({
  17. 'LandSupplySqlAgent': '返回的数据结果面积单位是亩,金额单位是万元, 楼面价和土地单价的单位是万元/平方米',
  18. 'LandUseSqlAgent': """
  19. 返回的数据结果面积单位是公顷,不要对其进行转换,不要对其四舍五入
  20. 各一级类包含以下二级类:
  21. 湿地:包含以下几种土地利用现状小类 红树林地、森林沼泽、灌丛沼泽、沼泽草地、沿海滩涂、内陆滩涂和沼泽地,
  22. 耕地:包含以下几种土地利用现状小类 水田,水浇地和旱地,
  23. 园林:包含以下几种土地利用现状小类 果园,茶园,橡胶园和其他园地
  24. 林地:包含以下几种土地利用现状小类 乔木林地,竹林地,灌木林地和其他林地
  25. 草地:包含以下几种土地利用现状小类 天然牧草地,人工牧草地和其他草地
  26. 城镇村及工矿用地:包含以下几种土地利用现状小类 城市用地,建制镇用地,村庄用地,采矿用地和风景名胜及特殊用地
  27. 交通运输用地:包含以下几种土地利用现状小类 铁路用地,轨道交通用地,公路用地,农村道路,机场用地,港口码头用地和管道运输用地
  28. 水域及水利设施用地:包含以下几种土地利用现状小类 河流水面,湖泊水面,水库水面,坑塘水面,沟渠,水工建筑用地和冰川及常年积雪
  29. """,
  30. 'SpatialAnalysisAgent': """
  31. 1. 总结时不要输出图形的wkt信息或者其他坐标点信息
  32. """,
  33. 'GisLayerOperationAgent':"""
  34. 不对结果进行总结,请直接回答没有相关数据
  35. """,
  36. 'LandSiteSelectionSqlAgent': """
  37. 必须按照markdown格式输出的地块信息。以下是输出信息的参考,请将<>替换成真实的内容:
  38. ### 1.<地块名称>
  39. ### 2.<地块名称>
  40. 下面是输出时的注意事项:
  41. 1.生成markdown 注意必须严格使用换行符,不可章节之间出现没有换行符的情况
  42. 2.必须严格按照上面的结构输出信息
  43. 3.不要输出除上面结构之外的信息
  44. """,
  45. })
  46. class SummaryAgent(BaseSubAgent):
  47. def __init__(self, llm=None, llm_name=None, stream=False, name='summary'):
  48. super().__init__(llm=llm, llm_name=llm_name, stream=stream, name=name)
  49. async def run(self, plan_context: PlanResponseContextManager, messages: List[str]):
  50. query = plan_context.get_summary_context()
  51. _messages = [{
  52. "role": "system",
  53. "content": self.handle_prompt(plan_context, SYSTEM_PROMPT)
  54. }]
  55. _messages.extend(messages)
  56. _messages.append({
  57. 'role': 'user',
  58. 'content': query
  59. })
  60. # print(f"query of summary agent: {query}")
  61. # for msg in _messages:
  62. for i, msg in enumerate(_messages):
  63. if not isinstance(msg, dict):
  64. msg = dict(msg)
  65. if msg['type'].value == 1:
  66. msg['role'] = 'user'
  67. msg['content'] = msg['data']
  68. else:
  69. msg['role'] = 'assistant'
  70. msg['content'] = dict(msg['data'])['exec_res'][0]
  71. msg['history'] = True
  72. del msg['data']
  73. del msg['type']
  74. _messages[i] = msg
  75. if 'history' in msg and msg['history']:
  76. print('is history messsage')
  77. else:
  78. yield ChatResponseChoice(role=msg['role'], content=msg['content'])
  79. retry_round = 0
  80. self.exec_res = 'Summary error, please try again...'
  81. self.is_success = False
  82. while retry_round < self.max_retry_round:
  83. try:
  84. rsp = await self.llm.chat(model=self.llm_name, messages=_messages, stream=self.stream)
  85. if self.stream:
  86. res = ""
  87. async for chunk in rsp:
  88. res += chunk
  89. yield ChatResponseStreamChoice(role='assistant', delta=chunk)
  90. yield ChatResponseStreamChoice(role='assistant', finish_reason='stop')
  91. else:
  92. res = rsp
  93. yield ChatResponseChoice(role='assistant', content=rsp)
  94. print(f'summary input: {query} \n summary output: {res}')
  95. self.exec_res = res
  96. self.is_success = True
  97. break
  98. except Exception as e:
  99. traceback.print_exc()
  100. if self.stream:
  101. yield ChatResponseStreamChoice(role='assistant', finish_reason='flush')
  102. if isinstance(e, RemoteProtocolError):
  103. await asyncio.sleep(2 ** self.max_retry_round + 2 ** (self.max_retry_round - retry_round + 1))
  104. retry_round += 1
  105. def handle_prompt(self, plan_context: PlanResponseContextManager, prompt: str):
  106. plans = plan_context.plans
  107. for i, plan in enumerate(plans):
  108. if plan.action_name in dict.keys(agents_prompt):
  109. return prompt + agents_prompt[plan.action_name]
  110. return prompt