main.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. """
  2. AI Agent
  3. 使用LangChain v1.1.0、FastAPI和OpenAI模型,支持流式输出
  4. """
# Standard library
import json
import logging
import os
from typing import AsyncIterator

# Third-party
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langchain.agents import create_agent
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
# Load environment variables from a local .env file (if present).
load_dotenv()

# Initialize the FastAPI application.
app = FastAPI(title="AI Agent", version="1.1.0")

# Add CORS middleware.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# very permissive — restrict to concrete origins before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # should be set to specific domains in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
  28. # 配置OpenAI模型
  29. # 需要设置环境变量 OPENAI_API_KEY
  30. # 可以通过 os.environ["OPENAI_API_KEY"] = "your-api-key" 设置
  31. # 或者通过 .env 文件配置
  32. openai_model = ChatOpenAI(
  33. api_key=os.getenv('QWEN_API_KEY'),
  34. base_url=os.getenv("OPENAI_BASE_URL"),
  35. model=os.getenv("OPENAI_MODEL"),
  36. temperature=0.7,
  37. streaming=True,
  38. )
# create_agent builds a callable agent around the model and the given tools;
# the tool list is currently empty, so the agent answers purely from the model.
agent = create_agent(
    model=openai_model,
    tools=[],
    system_prompt="你是一名专业的AI助手,善于调用工具回答与时间和项目相关的问题。",
)
class ChatRequest(BaseModel):
    """Request payload accepted by both chat endpoints."""

    message: str                      # user message text
    conversation_id: str = "default"  # conversation identifier (echoed back)
class ChatResponse(BaseModel):
    """Response payload returned by the non-streaming chat endpoint."""

    content: str          # model-generated reply (or an error message)
    conversation_id: str  # conversation identifier from the request
  53. @app.get("/")
  54. async def root():
  55. """根路径"""
  56. return {
  57. "message": "AI Agent",
  58. "version": "0.0.1"
  59. }
  60. @app.post("/api/agent/chat", response_model=ChatResponse)
  61. async def chat(request: ChatRequest):
  62. """非流式聊天接口"""
  63. try:
  64. # 创建消息并交给Agent
  65. messages = [HumanMessage(content=request.message)]
  66. response = await agent.ainvoke({"messages": messages})
  67. print(f"结果: {response}")
  68. return ChatResponse(
  69. content=response["messages"][-1].content,
  70. conversation_id=request.conversation_id
  71. )
  72. except Exception as e:
  73. return ChatResponse(
  74. content=f"错误: {str(e)}",
  75. conversation_id=request.conversation_id
  76. )
  77. def _extract_text(content_block) -> str:
  78. """从LangChain内容块中提取纯文本"""
  79. if content_block is None:
  80. return ""
  81. if isinstance(content_block, str):
  82. return content_block
  83. # LangChain 可能返回 list[dict] 或 list[ContentBlock]
  84. text_parts = []
  85. if isinstance(content_block, list):
  86. for block in content_block:
  87. if isinstance(block, str):
  88. text_parts.append(block)
  89. elif isinstance(block, dict):
  90. if block.get("type") == "text":
  91. text_parts.append(block.get("text", ""))
  92. else:
  93. block_text = getattr(block, "text", None)
  94. if block_text:
  95. text_parts.append(block_text)
  96. else:
  97. block_text = getattr(content_block, "text", None)
  98. if block_text:
  99. text_parts.append(block_text)
  100. return "".join(text_parts)
  101. async def generate_stream(request: ChatRequest) -> AsyncIterator[str]:
  102. """生成流式响应"""
  103. try:
  104. # 创建消息并流式执行Agent
  105. messages = [HumanMessage(content=request.message)]
  106. async for token, metadata in agent.astream(
  107. {"messages": messages},
  108. stream_mode="messages"
  109. ):
  110. text = _extract_text(getattr(token, "content", None))
  111. if not text:
  112. text = _extract_text(getattr(token, "content_blocks", None))
  113. if not text:
  114. continue
  115. payload = json.dumps({"type": "text", "text": text}, ensure_ascii=False)
  116. yield payload
  117. # 发送结束标记
  118. yield "[DONE]"
  119. except Exception as e:
  120. error_payload = json.dumps({"type": "error", "text": str(e)}, ensure_ascii=False)
  121. yield error_payload
  122. yield "[DONE]"
  123. @app.post("/api/agent/chat/stream", response_description='{"type": "text", "text": "你好"}')
  124. async def chat_stream(request: ChatRequest):
  125. """流式聊天接口"""
  126. return EventSourceResponse(generate_stream(request))
  127. if __name__ == "__main__":
  128. import uvicorn
  129. # 检查API密钥
  130. if not os.getenv("OPENAI_API_KEY"):
  131. print("警告: 未设置 OPENAI_API_KEY 环境变量")
  132. print("请设置环境变量或创建 .env 文件")
  133. # 启动服务器
  134. # 注意: 如果直接运行仍有问题,建议使用命令行: uvicorn main:app --host 0.0.0.0 --port 8080
  135. try:
  136. uvicorn.run(
  137. "main:app", # 使用字符串形式更兼容
  138. host="127.0.0.1",
  139. port=8080,
  140. log_level="info",
  141. reload=False # 禁用reload以避免某些兼容性问题
  142. )
  143. except Exception as e:
  144. print(f"启动错误: {e}")
  145. print("\n建议使用命令行启动:")
  146. print("uvicorn main:app --host 0.0.0.0 --port 8080")