|
@@ -0,0 +1,99 @@
|
|
|
+import json
|
|
|
+import os
|
|
|
+from contextlib import asynccontextmanager
|
|
|
+from datetime import datetime
|
|
|
+from functools import partial
|
|
|
+from typing import Iterator, Dict
|
|
|
+import time
|
|
|
+import requests
|
|
|
+from config import LLM_SEARCH_HOST
|
|
|
+import uvicorn
|
|
|
+from pydantic import BaseModel
|
|
|
+from fastapi import FastAPI, Request, status
|
|
|
+from fastapi.middleware.cors import CORSMiddleware
|
|
|
+from fastapi.responses import JSONResponse, Response, StreamingResponse
|
|
|
+
|
|
|
+from config import UserSearchRequest, SERVER_PORT
|
|
|
+from src.common.logger import get_logger
|
|
|
+from src.rag_chat import kb_chat, chat_regenerate, chat_continue_ask
|
|
|
+from src.common_chat import chat_question_recommend
|
|
|
+
|
|
|
# User-facing fallback messages keyed by HTTP status code.
# 451: the question was flagged as policy-violating; 500: generic failure.
# NOTE(review): not referenced elsewhere in this chunk — presumably consumed
# by error handlers / callers outside this file; confirm before removing.
error_message_dict: Dict[int, str] = {
    451: "您的问题涉嫌违规,暂时不支持回答",
    500: "啊,难到我了!再试试别的吧"
}
|
|
|
+
|
|
|
# Module-level logger for this service.
logger = get_logger(__name__)
# Process start timestamp (naive local time), captured at import.
# NOTE(review): not read anywhere in this chunk — presumably for uptime reporting.
start_time: datetime = datetime.now()
|
|
|
+
|
|
|
+
|
|
|
def init() -> None:
    """One-time service initialization, invoked from the FastAPI lifespan hook.

    Currently a no-op placeholder; add warm-up work (model loading,
    connection pools, …) here.
    """
    pass
|
|
|
+
|
|
|
+
|
|
|
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """FastAPI lifespan context: run startup work once, then serve until shutdown.

    Everything before ``yield`` executes at startup; nothing is done on
    shutdown (no teardown after the yield).
    """
    init()  # one-time startup hook (currently a no-op)
    yield
|
|
|
+
|
|
|
+
|
|
|
# Application instance; the lifespan hook runs init() before serving.
app = FastAPI(
    lifespan=lifespan, title="AI Search Service", description="Anything you want in one search")

# Fully-open CORS policy: any origin, method, and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# broader than the CORS spec allows for credentialed requests — confirm this
# is intentional and restrict allow_origins in production if possible.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)
|
|
|
+
|
|
|
+
|
|
|
class ChatQuery(BaseModel):
    """Request body shared by all chat endpoints: a single free-text query."""
    # The user's question / prompt text.
    query: str
|
|
|
+
|
|
|
+
|
|
|
# RAG question answering.
@app.post("/chat/kb_chat")
def chat_kb_chat(chat_query: ChatQuery):
    """Stream a knowledge-base (RAG) answer for the submitted query."""
    answer_stream = kb_chat(chat_query.query)
    return StreamingResponse(answer_stream, media_type="application/octet-stream")
|
|
|
+
|
|
|
+
|
|
|
# Regenerate the answer for a previous question.
@app.post("/chat/regenerate")
def regenerate(chat_query: ChatQuery):
    """Stream a freshly regenerated answer for the submitted query."""
    answer_stream = chat_regenerate(chat_query.query)
    return StreamingResponse(answer_stream, media_type="application/octet-stream")
|
|
|
+
|
|
|
+
|
|
|
# Follow-up question on an existing conversation.
@app.post("/chat/continue_ask")
def continue_ask(chat_query: ChatQuery):
    """Stream an answer to a follow-up question."""
    answer_stream = chat_continue_ask(chat_query.query)
    return StreamingResponse(answer_stream, media_type="application/octet-stream")
|
|
|
+
|
|
|
+
|
|
|
# Suggest related questions.
@app.post("/chat/question_recommend")
def question_recommend(chat_query: ChatQuery):
    """Stream recommended follow-up questions for the submitted query."""
    suggestion_stream = chat_question_recommend(chat_query.query)
    return StreamingResponse(suggestion_stream, media_type="application/octet-stream")
|
|
|
+
|
|
|
+
|
|
|
# Liveness probe.
@app.get("/api/health")
async def health():
    """Return a static payload confirming the service is up."""
    return {"status": 200, "message": "health"}
|
|
|
+
|
|
|
+
|
|
|
def main():
    """Run the uvicorn server for this app.

    The listen port and worker count default to SERVER_PORT and 1, and can
    be overridden via the PORT and WORKERS environment variables.
    """
    listen_port = int(os.getenv("PORT", SERVER_PORT))
    worker_count = int(os.getenv("WORKERS", 1))
    uvicorn.run('main:app', host='0.0.0.0', port=listen_port, workers=worker_count)


if __name__ == '__main__':
    main()
|