run_server_async.py

import shutil
import tempfile
import asyncio
import sys
from zipfile import ZipFile
from sse_starlette.sse import EventSourceResponse
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import os

from qwen_agent.gis.utils.base_class import Geometry
from qwen_agent.gis.utils.geometry_parser import GeometryParser

# Tolerate duplicate OpenMP runtimes (common workaround for mixed conda/pip builds)
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Make the repository root importable before loading project-local modules
parent_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
sys.path.append(parent_dir)

from qwen_agent.planning.plan_executor import PlanExecutor
from qwen_agent.planning.plan_continue_executor import PlanContinueExecutor
from qwen_agent.llm.llm_client import LLMClient, LLMAsyncClient
from agent_config import LLMDict_Qwen_72B_1211, LLMDict_GPT4_TURBO
from agent_messages import BaseRequest
from xuanzhi_query import router as xz_router
prompt_lan = "CN"
llm_name = "qwen-plus"
llm_turbo_name = "gpt-4-turbo"
max_ref_token = 4000

# model_server = "http://10.10.0.10:7907/v1"
# model_server = "http://lq.lianqiai.cn:7905/v1"
# model_server = "http://172.20.28.16:20331/v1"
model_server = "http://ac.zjugis.com:8511/v1"
api_key = ""
server_host = "0.0.0.0"
server_port = 8511
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.include_router(xz_router)
# Response headers shared by all SSE endpoints
rspHeaders = {
    "Cache-Control": "no-cache",
    "Connection": "keep-alive",
    "Content-Type": "text/event-stream",
    "Transfer-Encoding": "chunked",
}
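# For reference: sse_starlette's EventSourceResponse wraps each yielded string
# as one SSE event, so a client sees lines like "data: <chunk>" followed by the
# literal "data: [DONE]" and "data: [FINISH]" sentinels emitted by
# call_with_stream below.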
# Infer the LLM source from the configured endpoint
if model_server.startswith("http"):
    source = "local"
elif model_server.startswith("dashscope"):
    source = "dashscope"


def rm_file(file_path):
    """Remove a file if it exists (best-effort cleanup helper)."""
    if os.path.exists(file_path):
        os.remove(file_path)


@app.post("/")
def index():
    return "Welcome to Lianqi AI"
# Spatial analysis assistant endpoint: accepts a zipped shapefile plus a question
@app.post("/subscribe/spatial")
async def upload_file(question: str = Form(...), history: list = Form(...), file: UploadFile = File(...)):
    print(f'history: {history}')
    # Only ZIP uploads are accepted
    if file.content_type != "application/zip":
        return {"detail": "Please upload a ZIP file"}
    # Write the upload into a temporary directory
    temp_dir = tempfile.TemporaryDirectory(dir=os.path.dirname(__file__) + "/upload_temps")
    zip_path = f"{temp_dir.name}/{file.filename}"
    with open(zip_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
        print(buffer)
    # Unpack the archive and locate the shapefile (.shp) inside it
    with ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(temp_dir.name)
    shp_file_path = ''
    for root, dirs, files in os.walk(temp_dir.name):
        for name in files:  # `name`, not `file`: avoid shadowing the UploadFile parameter
            if name.endswith('.shp'):
                shp_file_path = os.path.join(root, name)
    geoms = GeometryParser.parse_geom_shp_file(shp_file_path)
    # Stream the analysis; the question and the parsed geometries are composed
    # into a single (Chinese) prompt for the planner
    return EventSourceResponse(
        call_with_stream(f"原问题为:{question},上传的图形信息为: {Geometry.dumps_json(geoms)}"),
        media_type="text/event-stream",
        headers=rspHeaders,
    )
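# Example call (illustrative; the field names match the handler signature above,
# and parcels.zip is a hypothetical zipped shapefile):
#
#   curl -N -X POST http://localhost:8511/subscribe/spatial \
#        -F "question=Analyze the land use of this area" \
#        -F "history=[]" \
#        -F "file=@parcels.zip;type=application/zip"
#
# -N disables curl's buffering so SSE chunks are printed as they arrive.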
# Multi-turn conversation endpoint
@app.post("/subscribe/history", response_model=str)
async def subscribe_with_history(request: BaseRequest):
    print(request)
    return EventSourceResponse(
        call_with_stream(request.data, request.history),
        media_type="text/event-stream",
        headers=rspHeaders,
    )
@app.post("/subscribe/", response_model=str)
async def subscribe(request: BaseRequest):
    print(request)
    return EventSourceResponse(
        call_with_stream(request.data, request.history),
        media_type="text/event-stream",
        headers=rspHeaders,
    )
@app.get("/subscribe/{question}", response_model=str)
async def subscribe_get(question: str):  # renamed: was `subscribe`, which shadowed the POST handler
    return EventSourceResponse(
        call_with_stream(question),
        media_type="text/event-stream",
        headers=rspHeaders,
    )
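# Consuming the stream from Python, a minimal sketch (assumes the `httpx`
# package; any SSE-capable client works the same way):
#
#   import httpx
#
#   with httpx.stream("GET", "http://localhost:8511/subscribe/hello") as r:
#       for line in r.iter_lines():
#           if line.startswith("data:"):
#               print(line[len("data:"):].strip())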
@app.post("/subscribeByTurbo/", response_model=str)
async def subscribeByTurbo(question: BaseRequest):
    return EventSourceResponse(
        # llm_dict must be passed by keyword: positionally, the second slot is `history`
        call_with_stream(question.data, llm_dict=LLMDict_GPT4_TURBO),
        media_type="text/event-stream",
        headers=rspHeaders,
    )


@app.get("/subscribeByTurbo/{question}", response_model=str)
async def subscribeByTurboGet(question: str):  # renamed to avoid shadowing the POST handler
    return EventSourceResponse(
        call_with_stream(question, llm_dict=LLMDict_GPT4_TURBO),
        media_type="text/event-stream",
        headers=rspHeaders,
    )
@app.post("/clarification/", response_model=str)
async def clarification(request: BaseRequest):
    print("clarification: ", request)
    return EventSourceResponse(
        # isClarification must be passed by keyword; positionally this slot is `history`
        call_with_stream(request.data, isClarification=True),
        media_type="text/event-stream",
        headers=rspHeaders,
    )


@app.post("/clarificationByTurbo/", response_model=str)
async def clarificationByTurbo(request: BaseRequest):
    print("clarificationByTurbo: ", request)
    return EventSourceResponse(
        call_with_stream(request.data, isClarification=True, llm_dict=LLMDict_GPT4_TURBO),
        media_type="text/event-stream",
        headers=rspHeaders,
    )
# The synchronous client is kept for parity; the streaming path uses the async client
llm_client = LLMClient(model=llm_name, model_server=model_server)
llm_client_async = LLMAsyncClient(model=llm_name, model_server=model_server)
async def call_with_stream(
    question,
    history=None,  # avoid a mutable default argument
    isClarification=False,
    llm_dict=LLMDict_Qwen_72B_1211,
):
    history = history or []
    # Normalize history entries into {role, content} chat messages:
    # type == 1 marks a user turn; anything else is an assistant turn whose
    # payload may be a plain string or a dict carrying `exec_res` results.
    for i, msg in enumerate(history):
        if not isinstance(msg, dict):
            msg = dict(msg)
        if msg['type'].value == 1:
            msg['role'] = 'user'
            msg['content'] = msg['data']
        else:
            msg['role'] = 'assistant'
            if isinstance(msg['data'], str):
                msg['content'] = msg['data']
            else:
                msg['content'] = dict(msg['data'])['exec_res'][0]
        msg['history'] = True
        del msg['data']
        del msg['type']
        history[i] = msg
    # Clarification requests resume an existing plan; everything else starts a new one
    if isClarification:
        executor = PlanContinueExecutor(
            llm_dict=llm_dict, llm=llm_client_async, stream=True
        )
    else:
        executor = PlanExecutor(llm_dict=llm_dict, llm=llm_client_async, stream=True)
    async for rsp in executor.run(question, history):
        if not rsp:
            continue
        # Throttle without blocking the event loop (time.sleep would stall all requests)
        await asyncio.sleep(0.1)
        yield rsp
    yield "[DONE]"
    yield "[FINISH]"
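# For illustration: given the normalization above, a history entry such as
#   {"type": <enum with value 1, i.e. a user turn>, "data": "list the parcels"}
# (field names per agent_messages, inferred from the branches in call_with_stream)
# becomes
#   {"role": "user", "content": "list the parcels", "history": True}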
if __name__ == "__main__":
    # uvicorn.run("run_server_async:app", host=server_host, port=server_port, workers=5)
    # Passing the app object (rather than an import string) limits uvicorn to one worker
    uvicorn.run(app, host=server_host, port=server_port, workers=1)