main.py

import json
import os
from contextlib import asynccontextmanager
from datetime import datetime
from functools import partial
from typing import Iterator, Dict

import time
import requests

from config import LLM_SEARCH_HOST
import uvicorn
from pydantic import BaseModel
from fastapi import FastAPI, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse

from config import UserSearchRequest, SERVER_PORT
from src.common.logger import get_logger
from src.rag_chat import kb_chat, chat_regenerate, chat_continue_ask
from src.common_chat import chat_question_recommend

# User-facing error messages keyed by HTTP status code
error_message_dict: Dict[int, str] = {
    451: "您的问题涉嫌违规,暂时不支持回答",  # "Your question may violate policy and cannot be answered for now"
    500: "啊,难到我了!再试试别的吧"  # "Ah, you've stumped me! Try asking something else"
}

logger = get_logger(__name__)
start_time: datetime = datetime.now()


def init():
    pass


@asynccontextmanager
async def lifespan(_: FastAPI):
    # Run one-time initialization before the app starts serving requests
    init()
    yield


app = FastAPI(
    lifespan=lifespan,
    title="AI Search Service",
    description="Anything you want in one search",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ChatQuery(BaseModel):
    query: str


# RAG question answering
@app.post("/chat/kb_chat")
def chat_kb_chat(chat_query: ChatQuery):
    return StreamingResponse(kb_chat(chat_query.query), media_type="application/octet-stream")


# Regenerate an answer
@app.post("/chat/regenerate")
def regenerate(chat_query: ChatQuery):
    return StreamingResponse(chat_regenerate(chat_query.query), media_type="application/octet-stream")


# Follow-up question
@app.post("/chat/continue_ask")
def continue_ask(chat_query: ChatQuery):
    return StreamingResponse(chat_continue_ask(chat_query.query), media_type="application/octet-stream")


# Question recommendation
@app.post("/chat/question_recommend")
def question_recommend(chat_query: ChatQuery):
    return StreamingResponse(chat_question_recommend(chat_query.query), media_type="application/octet-stream")


# Health check
@app.get("/api/health")
async def health():
    return {
        "status": 200,
        "message": "health",
    }


def main():
    uvicorn.run(
        'main:app',
        host='0.0.0.0',
        port=int(os.getenv("PORT", SERVER_PORT)),
        workers=int(os.getenv("WORKERS", 1)),
    )


if __name__ == '__main__':
    main()
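
A minimal client sketch for consuming the streaming endpoints above, assuming the service is running locally; the base URL, port, and query text are placeholders chosen for illustration and are not part of main.py:

import requests

# Hypothetical base URL; adjust host/port to match SERVER_PORT or the PORT env var.
BASE_URL = "http://127.0.0.1:8000"


def stream_chat(endpoint: str, query: str) -> None:
    # The endpoints return a StreamingResponse, so read the body incrementally
    # instead of waiting for the full answer.
    with requests.post(f"{BASE_URL}{endpoint}", json={"query": query}, stream=True) as resp:
        resp.raise_for_status()
        for chunk in resp.iter_content(chunk_size=None):
            if chunk:
                print(chunk.decode("utf-8", errors="replace"), end="", flush=True)


if __name__ == "__main__":
    stream_chat("/chat/kb_chat", "What is retrieval-augmented generation?")

The same helper works for /chat/regenerate, /chat/continue_ask, and /chat/question_recommend, since all four routes accept the same {"query": ...} body and stream their output.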