"""FastAPI SSE server (run_server_async.py) for the Lianqi AI spatial-analysis assistant."""
  1. import shutil
  2. import tempfile
  3. import time
  4. import sys
  5. from typing import List
  6. from zipfile import ZipFile
  7. import json
  8. import fiona
  9. from pydantic import BaseModel
  10. from shapely.geometry import shape
  11. from sse_starlette.sse import EventSourceResponse
  12. from fastapi import FastAPI, UploadFile, File, Form
  13. from fastapi.middleware.cors import CORSMiddleware
  14. import uvicorn
  15. import os
  16. from qwen_agent.gis.utils.base_class import Geometry
  17. from qwen_agent.gis.utils.geometry_parser import GeometryParser
  18. os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
  19. parent_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
  20. sys.path.append(parent_dir)
  21. from qwen_agent.planning.plan_executor import PlanExecutor
  22. from qwen_agent.planning.plan_continue_executor import PlanContinueExecutor
  23. from qwen_agent.llm.llm_client import LLMClient, LLMAsyncClient
  24. from agent_config import LLMDict_Qwen_72B_1211, LLMDict_GPT4_TURBO
  25. from agent_messages import BaseRequest
  26. from qwen_agent.tools.tools import async_xzdb, async_db
  27. from qwen_agent.tools.gis.spatial_analysis.geo_analysis import intersect_kfq, intersect_gyyd, xzfx, sqsx, ghfx
  28. prompt_lan = "CN"
  29. llm_name = "qwen-plus"
  30. llm_turbo_name = "gpt-4-turbo"
  31. max_ref_token = 4000
  32. # model_server = "http://10.10.0.10:7907/v1"
  33. # model_server = "http://lq.lianqiai.cn:7905/v1"
  34. # model_server = "http://172.20.28.16:20331/v1"
  35. model_server = "http://ac.zjugis.com:8511/v1"
  36. api_key = ""
  37. server_host = "0.0.0.0"
  38. server_port = 8511
  39. app = FastAPI()
  40. app.add_middleware(
  41. CORSMiddleware,
  42. allow_origins=["*"],
  43. allow_credentials=True,
  44. allow_methods=["*"],
  45. allow_headers=["*"],
  46. )
  47. rspHeaders = {
  48. "Cache-Control": "no-cache",
  49. "Connection": "keep-alive",
  50. "Content-Type": "text/event-stream",
  51. "Transfer-Encoding": "chunked",
  52. }
  53. if model_server.startswith("http"):
  54. source = "local"
  55. elif model_server.startswith("dashscope"):
  56. source = "dashscope"
  57. def rm_file(file_path):
  58. if os.path.exists(file_path):
  59. os.remove(file_path)
  60. @app.post("/")
  61. def index():
  62. return "Welcome to Lianqi AI"
  63. # 空间分析助手接口
  64. @app.post("/subscribe/spatial")
  65. async def upload_file(question: str = Form(...), history: list = Form(...), file: UploadFile = File(...)):
  66. print(f'history: {history}')
  67. # 确保上传的文件是ZIP类型
  68. if file.content_type != "application/zip":
  69. return {"detail": "请上传ZIP文件"}
  70. # 将文件写入临时文件
  71. temp_dir = tempfile.TemporaryDirectory(dir=os.path.dirname(__file__) + "/upload_temps")
  72. with open(f"{temp_dir.name}/{file.filename}", "wb") as buffer:
  73. shutil.copyfileobj(file.file, buffer)
  74. print(buffer)
  75. # 解压缩
  76. with ZipFile(f"{temp_dir.name}/{file.filename}", "r") as zip_ref:
  77. zip_ref.extractall(temp_dir.name)
  78. shp_file_path = ''
  79. for root, dirs, files in os.walk(temp_dir.name):
  80. for file in files:
  81. if file.endswith('.shp'):
  82. shp_file_path = os.path.join(root, file)
  83. geoms = GeometryParser.parse_geom_shp_file(shp_file_path)
  84. return EventSourceResponse(
  85. call_with_stream(f"原问题为:{question},上传的图形信息为: {Geometry.dumps_json(geoms)}"),
  86. media_type="text/event-stream",
  87. headers=rspHeaders,
  88. )
  89. # 连续对话
  90. @app.post("/subscribe/history", response_model=str)
  91. async def subscribe_with_history(request: BaseRequest):
  92. print(request)
  93. return EventSourceResponse(
  94. call_with_stream(request.data, request.history),
  95. media_type="text/event-stream",
  96. headers=rspHeaders,
  97. )
  98. @app.post("/subscribe/", response_model=str)
  99. async def subscribe(request: BaseRequest):
  100. print(request)
  101. return EventSourceResponse(
  102. call_with_stream(request.data, request.history),
  103. media_type="text/event-stream",
  104. headers=rspHeaders,
  105. )
  106. @app.get("/subscribe/{question}", response_model=str)
  107. async def subscribe(question: str):
  108. return EventSourceResponse(
  109. call_with_stream(question),
  110. media_type="text/event-stream",
  111. headers=rspHeaders,
  112. )
  113. @app.post("/subscribeByTurbo/", response_model=str)
  114. async def subscribeByTurbo(question: BaseRequest):
  115. return EventSourceResponse(
  116. call_with_stream(question.data, False, LLMDict_GPT4_TURBO),
  117. media_type="text/event-stream",
  118. headers=rspHeaders,
  119. )
  120. @app.get("/subscribeByTurbo/{question}", response_model=str)
  121. async def subscribeByTurbo(question: str):
  122. return EventSourceResponse(
  123. call_with_stream(question, False, LLMDict_GPT4_TURBO),
  124. media_type="text/event-stream",
  125. headers=rspHeaders,
  126. )
  127. @app.post("/clarification/", response_model=str)
  128. async def clarification(request: BaseRequest):
  129. print("clarification: ", request)
  130. return EventSourceResponse(
  131. call_with_stream(request.data, True),
  132. media_type="text/event-stream",
  133. headers=rspHeaders,
  134. )
  135. @app.post("/clarificationByTurbo/", response_model=str)
  136. async def clarificationByTurbo(request: BaseRequest):
  137. print("clarificationByTurbo: ", request)
  138. return EventSourceResponse(
  139. call_with_stream(request.data, True, LLMDict_GPT4_TURBO),
  140. media_type="text/event-stream",
  141. headers=rspHeaders,
  142. )
  143. @app.get("/kgQuery")
  144. async def kgQuery(id: str):
  145. sql = f'select id, xzqmc, xzqdm, dymc, yddm, ydxz, ydmj, rjlsx, rjlxx, jzmdsx, jzmdxx, jzgdsx, jzgdxx, ldlsx, ldlxx, pfwh, pfsj, st_area(shape::geography) as pfmarea,st_astext(shape) as geom, st_astext(st_centroid(shape)) as center_wkt from sde.kzxxxgh where id in ({id})'
  146. res_tuples = await async_db.run(sql)
  147. result, success = res_tuples
  148. return json.loads(result)
  149. @app.get("/klyzyQuery")
  150. async def klyzyQuery(id: str):
  151. sql = f'select *, st_astext(shape) as geom, st_astext(st_centroid(shape)) as center_wkt from sde.ecgap_klyzy where id in ({id})'
  152. res_tuples = await async_db.run(sql)
  153. result, success = res_tuples
  154. return json.loads(result)
  155. @app.get("/yjjbntQuery")
  156. async def yjjbntQuery(id: str):
  157. sql = f'select *,st_astext(shape) as geom from ddd.gcs330000g2001_yjjbnt_gx_xsb where objectid in ({id})'
  158. res_tuples = await async_xzdb.run(sql)
  159. result, success = res_tuples
  160. return json.loads(result)
  161. @app.get("/kfqintersect")
  162. async def kfqintersect(wkt: str):
  163. result = await intersect_kfq(
  164. wkt)
  165. return result
  166. @app.get("/gyydintersect")
  167. async def gyydintersect(wkt: str):
  168. result = await intersect_gyyd(wkt)
  169. return result
# Shared LLM clients: the async client feeds the streaming executors in
# call_with_stream; the sync client is kept for parity with other entry points.
llm_client = LLMClient(model=llm_name, model_server=model_server)
llm_client_async = LLMAsyncClient(model=llm_name, model_server=model_server)
  172. async def call_with_stream(
  173. question,
  174. history=[],
  175. isClarification=False,
  176. llm_dict=LLMDict_Qwen_72B_1211,
  177. ):
  178. for i, msg in enumerate(history):
  179. if not isinstance(msg, dict):
  180. msg = dict(msg)
  181. if msg['type'].value == 1:
  182. msg['role'] = 'user'
  183. msg['content'] = msg['data']
  184. else:
  185. msg['role'] = 'assistant'
  186. if isinstance(msg['data'], str):
  187. msg['content'] = msg['data']
  188. else:
  189. msg['content'] = dict(msg['data'])['exec_res'][0]
  190. msg['history'] = True
  191. del msg['data']
  192. del msg['type']
  193. history[i] = msg
  194. if isClarification:
  195. executor = PlanContinueExecutor(
  196. llm_dict=llm_dict, llm=llm_client_async, stream=True
  197. )
  198. else:
  199. executor = PlanExecutor(llm_dict=llm_dict, llm=llm_client_async, stream=True)
  200. async for rsp in executor.run(question, history):
  201. if not rsp:
  202. continue
  203. else:
  204. time.sleep(0.1)
  205. yield rsp
  206. yield "[DONE]"
  207. yield "[FINISH]"
  208. if __name__ == "__main__":
  209. # uvicorn.run("run_server_async:app", host=server_host, port=server_port, workers=5)
  210. uvicorn.run(app, host=server_host, port=server_port, workers=1)