run_lianqi_server.py 2.8 KB

import importlib
import os
import sys
import time

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

# The repository root must be on sys.path before the qwen_agent imports.
parent_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
sys.path.append(parent_dir)

from qwen_agent.planning.plan_executor import PlanExecutor
from qwen_agent.planning.plan_continue_executor import PlanContinueExecutor
from qwen_agent.messages.context_message import SystemSignal

# Server configuration.
prompt_lan = "CN"
llm_name = "qwen-plus"
max_ref_token = 4000
model_server = "http://10.10.0.10:7909/v1"
api_key = ""
server_host = "127.0.0.1"

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

rspHeaders = {
    "Cache-Control": "no-cache",
    "Connection": "keep-alive",
    "Transfer-Encoding": "chunked",
}

# Infer the model source from the server address ("source" is set here but
# not consumed further down).
if model_server.startswith("http"):
    source = "local"
elif model_server.startswith("dashscope"):
    source = "dashscope"

# Instantiate the LLM wrapper that matches the configured model name.
if llm_name.startswith("gpt"):
    module = "qwen_agent.llm.gpt"
    llm = importlib.import_module(module).GPT(llm_name)
elif llm_name.startswith("Qwen") or llm_name.startswith("qwen"):
    module = "qwen_agent.llm.qwen"
    llm = importlib.import_module(module).Qwen(
        llm_name, model_server=model_server, api_key=api_key
    )
else:
    raise NotImplementedError

planContinueExecutor = PlanContinueExecutor(enable_critic=False, llm=llm, stream=True)
planExecutor = PlanExecutor(enable_critic=False, llm=llm, stream=True)


@app.post("/")
def index():
    return "Welcome to Lianqi AI"


# The same SSE stream is exposed under both POST and GET.
@app.post("/subscribe/{question}", response_model=str)
async def subscribe(question: str):
    return EventSourceResponse(
        call_with_stream(question),
        media_type="text/event-stream",
        headers=rspHeaders,
    )


@app.get("/subscribe/{question}", response_model=str)
async def subscribe_get(question: str):
    # Renamed from "subscribe" so it no longer shadows the POST handler.
    return EventSourceResponse(
        call_with_stream(question),
        media_type="text/event-stream",
        headers=rspHeaders,
    )


class ClarificationRequest(BaseModel):
    data: str


@app.post("/clarification/", response_model=str)
async def clarification(request: ClarificationRequest):
    print("clarification: ", request)
    return EventSourceResponse(
        call_with_stream(request.data, True),
        media_type="text/event-stream",
        headers=rspHeaders,
    )

async def call_with_stream(question, isClarification=False):
    executor = planContinueExecutor if isClarification else planExecutor
    for rsp in executor.run(question, []):
        yield f"{rsp}"
    # EventSourceResponse adds the "data:" framing itself, so yield the bare
    # sentinel rather than a pre-framed "data: [DONE]" string.
    yield "[DONE]"


if __name__ == "__main__":
    # uvicorn needs an import string, not an app object, to spawn workers.
    uvicorn.run("run_lianqi_server:app", host="0.0.0.0", port=20020, workers=10)
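
A quick way to exercise the streaming endpoint once the server is up: a minimal client sketch using the third-party requests library. The host, port, and path mirror the uvicorn settings and routes above; "hello" is a placeholder question, and the file name is hypothetical.

# sse_client_example.py (hypothetical helper, not part of the server)
import requests

with requests.get(
    "http://127.0.0.1:20020/subscribe/hello",
    stream=True,
    headers={"Accept": "text/event-stream"},
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        # sse_starlette frames each chunk as a "data: ..." line.
        if line:
            print(line)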