run_server_async.py

import asyncio
import os
import shutil
import sys
import tempfile
from zipfile import ZipFile

from sse_starlette.sse import EventSourceResponse
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

from qwen_agent.gis.utils.base_class import Geometry
from qwen_agent.gis.utils.geometry_parser import GeometryParser

# Allow duplicate OpenMP runtimes (common workaround for mixed conda/pip builds)
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Make the repository root importable before loading the agent modules
parent_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
sys.path.append(parent_dir)

from qwen_agent.planning.plan_executor import PlanExecutor
from qwen_agent.planning.plan_continue_executor import PlanContinueExecutor
from qwen_agent.llm.llm_client import LLMClient, LLMAsyncClient
from agent_config import LLMDict_Qwen_72B_1211, LLMDict_GPT4_TURBO
from agent_messages import BaseRequest
prompt_lan = "CN"
llm_name = "qwen-plus"
llm_turbo_name = "gpt-4-turbo"
max_ref_token = 4000

# model_server = "http://10.10.0.10:7907/v1"
# model_server = "http://lq.lianqiai.cn:7905/v1"
model_server = "http://10.36.162.54:20331/v1"
api_key = ""

server_host = "0.0.0.0"
server_port = 8511
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

rspHeaders = {
    "Cache-Control": "no-cache",
    "Connection": "keep-alive",
    "Content-Type": "text/event-stream",
    "Transfer-Encoding": "chunked",
}
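# These headers keep the SSE stream usable end to end: "no-cache" stops
# intermediaries from buffering events, "keep-alive" holds the connection
# open, and the text/event-stream content type marks the response as SSE.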
if model_server.startswith("http"):
    source = "local"
elif model_server.startswith("dashscope"):
    source = "dashscope"
def rm_file(file_path):
    if os.path.exists(file_path):
        os.remove(file_path)
@app.post("/")
def index():
    return "Welcome to Lianqi AI"
# Spatial analysis assistant endpoint
@app.post("/subscribe/spatial")
async def upload_file(question: str = Form(...), history: list = Form(...), file: UploadFile = File(...)):
    print(f'history: {history}')
    # Make sure the upload is a ZIP archive
    if file.content_type != "application/zip":
        return {"detail": "请上传ZIP文件"}  # "Please upload a ZIP file"
    # Write the upload into a temporary directory
    temp_dir = tempfile.TemporaryDirectory(dir=os.path.dirname(__file__) + "/upload_temps")
    with open(f"{temp_dir.name}/{file.filename}", "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    # Unzip the archive in place
    with ZipFile(f"{temp_dir.name}/{file.filename}", "r") as zip_ref:
        zip_ref.extractall(temp_dir.name)
    # Locate the shapefile; `fname` avoids shadowing the `file` parameter
    shp_file_path = ''
    for root, dirs, files in os.walk(temp_dir.name):
        for fname in files:
            if fname.endswith('.shp'):
                shp_file_path = os.path.join(root, fname)
    if not shp_file_path:
        return {"detail": "ZIP中未找到.shp文件"}  # no shapefile found in the archive
    geoms = GeometryParser.parse_geom_shp_file(shp_file_path)
    # The prompt restates the question plus the parsed geometry (in Chinese)
    return EventSourceResponse(
        call_with_stream(f"原问题为:{question},上传的图形信息为: {Geometry.dumps_json(geoms)}"),
        media_type="text/event-stream",
        headers=rspHeaders,
    )
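# --- Illustrative client for /subscribe/spatial (a sketch, not used by the
# server). Assumes the server runs on localhost:8511 and "data.zip" holds a
# shapefile; the endpoint streams SSE, so the response is read line by line.
def _example_spatial_client():
    import requests  # assumed to be available on the client side
    with open("data.zip", "rb") as f:
        resp = requests.post(
            "http://localhost:8511/subscribe/spatial",
            data={"question": "What is the total area of these shapes?", "history": "[]"},
            files={"file": ("data.zip", f, "application/zip")},
            stream=True,
        )
        for line in resp.iter_lines(decode_unicode=True):
            if line:  # SSE events arrive as "data: ..." lines
                print(line)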
# Multi-turn conversation endpoint
@app.post("/subscribe/history", response_model=str)
async def subscribe_with_history(request: BaseRequest):
    print(request)
    return EventSourceResponse(
        call_with_stream(request.data, request.history),
        media_type="text/event-stream",
        headers=rspHeaders,
    )
@app.post("/subscribe/", response_model=str)
async def subscribe(request: BaseRequest):
    print(request)
    return EventSourceResponse(
        call_with_stream(request.data, request.history),
        media_type="text/event-stream",
        headers=rspHeaders,
    )
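# Illustrative payload for the two POST endpoints above, assuming BaseRequest
# exposes `data` (the current question) and `history` (prior turns); curl's
# -N flag disables buffering so events print as they arrive:
#   curl -N -X POST http://localhost:8511/subscribe/ \
#        -H "Content-Type: application/json" \
#        -d '{"data": "hello", "history": []}'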
@app.get("/subscribe/{question}", response_model=str)
async def subscribe_get(question: str):  # renamed to avoid redefining `subscribe`
    return EventSourceResponse(
        call_with_stream(question),
        media_type="text/event-stream",
        headers=rspHeaders,
    )
@app.post("/subscribeByTurbo/", response_model=str)
async def subscribeByTurbo(question: BaseRequest):
    return EventSourceResponse(
        # Pass llm_dict by keyword; positionally it would land in isClarification
        call_with_stream(question.data, llm_dict=LLMDict_GPT4_TURBO),
        media_type="text/event-stream",
        headers=rspHeaders,
    )
@app.get("/subscribeByTurbo/{question}", response_model=str)
async def subscribeByTurbo_get(question: str):  # renamed to avoid redefining `subscribeByTurbo`
    return EventSourceResponse(
        call_with_stream(question, llm_dict=LLMDict_GPT4_TURBO),
        media_type="text/event-stream",
        headers=rspHeaders,
    )
@app.post("/clarification/", response_model=str)
async def clarification(request: BaseRequest):
    print("clarification: ", request)
    return EventSourceResponse(
        # isClarification must be passed by keyword; positionally True would
        # be consumed as `history`
        call_with_stream(request.data, isClarification=True),
        media_type="text/event-stream",
        headers=rspHeaders,
    )
@app.post("/clarificationByTurbo/", response_model=str)
async def clarificationByTurbo(request: BaseRequest):
    print("clarificationByTurbo: ", request)
    return EventSourceResponse(
        call_with_stream(request.data, isClarification=True, llm_dict=LLMDict_GPT4_TURBO),
        media_type="text/event-stream",
        headers=rspHeaders,
    )
llm_client = LLMClient(model=llm_name, model_server=model_server)
llm_client_async = LLMAsyncClient(model=llm_name, model_server=model_server)
async def call_with_stream(
    question,
    history=None,  # None, not []: a mutable default would be shared across calls
    isClarification=False,
    llm_dict=LLMDict_Qwen_72B_1211,
):
    history = history or []
    # Normalize front-end history records into {'role', 'content'} chat messages
    for i, msg in enumerate(history):
        if not isinstance(msg, dict):
            msg = dict(msg)
        if msg['type'].value == 1:  # user turn
            msg['role'] = 'user'
            msg['content'] = msg['data']
        else:  # assistant turn
            msg['role'] = 'assistant'
            if isinstance(msg['data'], str):
                msg['content'] = msg['data']
            else:
                msg['content'] = dict(msg['data'])['exec_res'][0]
            msg['history'] = True
        del msg['data']
        del msg['type']
        history[i] = msg
    if isClarification:
        executor = PlanContinueExecutor(
            llm_dict=llm_dict, llm=llm_client_async, stream=True
        )
    else:
        executor = PlanExecutor(llm_dict=llm_dict, llm=llm_client_async, stream=True)
    async for rsp in executor.run(question, history):
        if not rsp:
            continue
        # Pace the stream with asyncio.sleep; time.sleep here would block the
        # event loop and stall every other request
        await asyncio.sleep(0.1)
        yield rsp
    yield "[DONE]"
    yield "[FINISH]"
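# Shape of the normalization above (illustrative; assumes msg['type'] is an
# enum whose .value is 1 for user turns and something else for bot turns):
#   {'type': <USER: 1>, 'data': 'hi'}
#       -> {'role': 'user', 'content': 'hi'}
#   {'type': <BOT: 2>, 'data': {'exec_res': ['answer']}}
#       -> {'role': 'assistant', 'content': 'answer', 'history': True}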
if __name__ == "__main__":
    uvicorn.run(app, host=server_host, port=server_port, workers=1)
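# To launch (matching the __main__ block above):
#   python run_server_async.py
# or, equivalently, via the uvicorn CLI:
#   uvicorn run_server_async:app --host 0.0.0.0 --port 8511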