- import time
- import sys
- from sse_starlette.sse import EventSourceResponse
- import importlib
- from fastapi import FastAPI
- from pydantic import BaseModel
- from fastapi.middleware.cors import CORSMiddleware
- import uvicorn
- import os
# Make the repository root importable so the `qwen_agent` package resolves
# when this server script is run directly from its subdirectory.
parent_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
sys.path.append(parent_dir)
- from qwen_agent.planning.plan_executor import PlanExecutor
- from qwen_agent.planning.plan_continue_executor import PlanContinueExecutor
- from qwen_agent.messages.context_message import SystemSignal
- prompt_lan = "CN"
- llm_name = "qwen-plus"
- max_ref_token = 4000
- model_server = "http://10.10.0.10:7909/v1"
- api_key = ""
- server_host = "127.0.0.1"
# FastAPI application with fully-open CORS so browser front-ends on any
# origin can call the streaming endpoints below.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Headers attached to every SSE response: disable caching and keep the
# connection open for chunked streaming.
rspHeaders = {
    "Cache-Control": "no-cache",
    "Connection": "keep-alive",
    "Transfer-Encoding": "chunked",
}
- if model_server.startswith("http"):
- source = "local"
- elif model_server.startswith("dashscope"):
- source = "dashscope"
- if llm_name.startswith("gpt"):
- module = "qwen_agent.llm.gpt"
- llm = importlib.import_module(module).GPT(llm_name)
- elif llm_name.startswith("Qwen") or llm_name.startswith("qwen"):
- module = "qwen_agent.llm.qwen"
- llm = importlib.import_module(module).Qwen(
- llm_name, model_server=model_server, api_key=api_key
- )
- else:
- raise NotImplementedError
- planContinueExecutor = PlanContinueExecutor(enable_critic=False, llm=llm, stream=True)
- planExecutor = PlanExecutor(enable_critic=False, llm=llm, stream=True)
- @app.post("/")
- def index():
- return "Welcome to Lianqi AI"
- @app.post("/subscribe/{question}", response_model=str)
- async def subscribe(question: str):
- return EventSourceResponse(
- call_with_stream(question),
- media_type="text/event-stream",
- headers=rspHeaders,
- )
- @app.get("/subscribe/{question}", response_model=str)
- async def subscribe(question: str):
- return EventSourceResponse(
- call_with_stream(question),
- media_type="text/event-stream",
- headers=rspHeaders,
- )
class ClarificationRequest(BaseModel):
    """Request body for /clarification/: the user's clarification text."""
    # Raw clarification answer forwarded to the continuation executor.
    data: str
- @app.post("/clarification/", response_model=str)
- async def clarification(request: ClarificationRequest):
- print("clarification: ", request)
- return EventSourceResponse(
- call_with_stream(request.data, True),
- media_type="text/event-stream",
- headers=rspHeaders,
- )
async def call_with_stream(question, isClarification=False):
    """Yield streamed executor responses for *question*, then a DONE marker.

    Chooses the clarification-continuation executor when *isClarification*
    is true, otherwise the plain plan executor.

    NOTE(review): the executor is iterated synchronously inside an async
    generator, which blocks the event loop while the model streams —
    confirm whether this should be offloaded to a thread.
    """
    executor = planContinueExecutor if isClarification else planExecutor
    for rsp in executor.run(question, []):
        yield f"{rsp}"
    yield "data: [DONE]"
- if __name__ == "__main__":
- uvicorn.run(app, host="0.0.0.0", port=20020, workers=10)