123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- import json
- import os
- from contextlib import asynccontextmanager
- from datetime import datetime
- from functools import partial
- from typing import Iterator, Dict
- import time
- import requests
- from config import LLM_SEARCH_HOST
- import uvicorn
- from pydantic import BaseModel
- from fastapi import FastAPI, Request, status
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import JSONResponse, Response, StreamingResponse
- from config import UserSearchRequest, SERVER_PORT
- from src.common.logger import get_logger
- from src.rag_chat import kb_chat, chat_regenerate, chat_continue_ask
- from src.common_chat import chat_question_recommend
# Human-readable messages returned to clients for specific HTTP error codes.
error_message_dict: Dict[int, str] = {
    451: "您的问题涉嫌违规,暂时不支持回答",
    500: "啊,难到我了!再试试别的吧",
}

logger = get_logger(__name__)

# Timestamp captured at module import time, marking when the service started.
start_time: datetime = datetime.now()
def init() -> None:
    """One-time service initialization hook (currently a no-op placeholder)."""
@asynccontextmanager
async def lifespan(_: FastAPI):
    """FastAPI lifespan hook: run init() at startup; nothing to do at shutdown."""
    init()
    yield
# Application instance; startup logic is attached via the lifespan hook.
app = FastAPI(
    lifespan=lifespan,
    title="AI Search Service",
    description="Anything you want in one search",
)

# Permit cross-origin requests from any origin, with credentials,
# for all methods and headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class ChatQuery(BaseModel):
    """Request body shared by all chat endpoints: a single user query string."""

    query: str
# RAG question answering
@app.post("/chat/kb_chat")
def chat_kb_chat(chat_query: ChatQuery):
    """Stream the knowledge-base (RAG) answer for the submitted query."""
    answer_stream = kb_chat(chat_query.query)
    return StreamingResponse(answer_stream, media_type="application/octet-stream")
# Regenerate the answer for a question
@app.post("/chat/regenerate")
def regenerate(chat_query: ChatQuery):
    """Stream a freshly regenerated answer for the submitted query."""
    answer_stream = chat_regenerate(chat_query.query)
    return StreamingResponse(answer_stream, media_type="application/octet-stream")
# Follow-up question
@app.post("/chat/continue_ask")
def continue_ask(chat_query: ChatQuery):
    """Stream the answer to a follow-up question in an ongoing conversation."""
    answer_stream = chat_continue_ask(chat_query.query)
    return StreamingResponse(answer_stream, media_type="application/octet-stream")
# Question recommendation
@app.post("/chat/question_recommend")
def question_recommend(chat_query: ChatQuery):
    """Stream recommended follow-up questions related to the submitted query."""
    answer_stream = chat_question_recommend(chat_query.query)
    return StreamingResponse(answer_stream, media_type="application/octet-stream")
# healthcheck
@app.get("/api/health")
async def health():
    """Liveness probe: report that the service is up."""
    return {"status": 200, "message": "health"}
def main() -> None:
    """Launch the uvicorn server on 0.0.0.0.

    The listening port and worker count can be overridden via the PORT
    and WORKERS environment variables; they default to SERVER_PORT and 1.
    """
    port = int(os.getenv("PORT", SERVER_PORT))
    workers = int(os.getenv("WORKERS", 1))
    # The app is given as an import string so that multiple worker
    # processes can each import and construct it independently.
    uvicorn.run('main:app', host='0.0.0.0', port=port, workers=workers)


if __name__ == '__main__':
    main()
|