import json
import os
from contextlib import asynccontextmanager
from datetime import datetime
from functools import partial
from typing import Iterator, Dict
import time

import requests
from config import LLM_SEARCH_HOST
import uvicorn
from pydantic import BaseModel
from fastapi import FastAPI, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse

from config import UserSearchRequest, SERVER_PORT
from src.common.logger import get_logger
from src.rag_chat import kb_chat, chat_regenerate, chat_continue_ask
from src.common_chat import chat_question_recommend

# User-facing fallback messages keyed by integer code (presumably HTTP status
# codes -- verify against whoever consumes this mapping; it is not referenced
# in the visible part of this module).
error_message_dict: Dict[int, str] = {
    451: "您的问题涉嫌违规,暂时不支持回答",
    500: "啊,难到我了!再试试别的吧"
}

logger = get_logger(__name__)

# Captured once at import time; datetime.now() yields a naive local timestamp.
start_time: datetime = datetime.now()


def init():
    """Startup hook invoked by ``lifespan``; currently does nothing."""


@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Run ``init()`` before the application starts serving requests.

    Nothing is executed after the ``yield``, so there is no shutdown work.
    """
    init()
    yield


app = FastAPI(
    lifespan=lifespan,
    title="AI Search Service",
    description="Anything you want in one search",
)

# NOTE(review): wildcard origins are combined with allow_credentials=True.
# FastAPI's CORS documentation advises listing origins explicitly when
# credentials are allowed -- confirm this permissive setup is intentional.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)


# Request body shared by every /chat/* endpoint: a single ``query`` string.
class ChatQuery(BaseModel):
    query: str


def _octet_stream(chunks):
    """Wrap *chunks* in a StreamingResponse served as application/octet-stream."""
    return StreamingResponse(chunks, media_type="application/octet-stream")


# RAG question answering: streams whatever kb_chat produces for the query.
@app.post("/chat/kb_chat")
def chat_kb_chat(chat_query: ChatQuery):
    answer_stream = kb_chat(chat_query.query)
    return _octet_stream(answer_stream)


# Regenerate: streams whatever chat_regenerate produces for the query.
@app.post("/chat/regenerate")
def regenerate(chat_query: ChatQuery):
    answer_stream = chat_regenerate(chat_query.query)
    return _octet_stream(answer_stream)


# Follow-up question: streams whatever chat_continue_ask produces for the query.
@app.post("/chat/continue_ask")
def continue_ask(chat_query: ChatQuery):
    answer_stream = chat_continue_ask(chat_query.query)
    return _octet_stream(answer_stream)


# Question recommendation: streams whatever chat_question_recommend produces.
@app.post("/chat/question_recommend")
def question_recommend(chat_query: ChatQuery):
    suggestion_stream = chat_question_recommend(chat_query.query)
    return _octet_stream(suggestion_stream)


# Health check: always answers with a fixed payload.
@app.get("/api/health")
async def health():
    payload = {
        "status": 200,
        "message": "health",
    }
    return payload


def main():
    """Start uvicorn for ``main:app`` on all interfaces.

    The port comes from the ``PORT`` environment variable (falling back to
    ``SERVER_PORT``) and the worker count from ``WORKERS`` (falling back
    to 1). The app is passed as an import string, so this module must be
    importable as ``main``.
    """
    port = int(os.getenv("PORT", SERVER_PORT))
    worker_count = int(os.getenv("WORKERS", 1))
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=port,
        workers=worker_count,
    )


if __name__ == "__main__":
    main()