| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176 |
- """
- AI Agent
- 使用LangChain v1.1.0、FastAPI和OpenAI模型,支持流式输出
- """
import json
import logging
import os
from typing import AsyncIterator

from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langchain.agents import create_agent
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
# Load variables from a local .env file (if present) into the process
# environment; this must run before any os.getenv() configuration lookup.
load_dotenv()

# FastAPI application served by uvicorn.
app = FastAPI(title="AI Agent", version="1.1.0")

# Allowed CORS origins, read from the comma-separated CORS_ALLOW_ORIGINS
# environment variable (e.g. "https://a.example,https://b.example").
# When the variable is unset (or contains no non-blank entries) this falls
# back to ["*"], i.e. any origin -- the previous hard-coded behaviour.
# Set explicit origins in production.
_cors_allow_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allow_origins,
    # NOTE(review): with credentials enabled and a "*" origin, Starlette echoes
    # the caller's Origin header back, so any site can make credentialed
    # requests. Restrict CORS_ALLOW_ORIGINS when cookies/auth are in use.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Chat model accessed through an OpenAI-compatible API.
# Configuration comes from environment variables (shell or the .env file):
#   QWEN_API_KEY     - API key; OPENAI_API_KEY is used when it is not set
#   OPENAI_BASE_URL  - base URL of the OpenAI-compatible endpoint
#   OPENAI_MODEL     - model name
openai_model = ChatOpenAI(
    api_key=os.getenv("QWEN_API_KEY") or os.getenv("OPENAI_API_KEY"),
    base_url=os.getenv("OPENAI_BASE_URL"),
    model=os.getenv("OPENAI_MODEL"),
    temperature=0.7,
    streaming=True,  # request token chunks so replies can be streamed
)

# Agent wrapping the model. create_agent builds a callable agent from the
# given tools; none are registered yet, so it currently just chats.
# NOTE(review): the system prompt says the assistant is good at *calling tools*
# for time/project questions, but tools=[] -- register those tools or adjust
# the prompt.
agent = create_agent(
    model=openai_model,
    tools=[],
    system_prompt="你是一名专业的AI助手,善于调用工具回答与时间和项目相关的问题。",
)
class ChatRequest(BaseModel):
    """聊天请求模型"""
    # Request body shared by /api/agent/chat and /api/agent/chat/stream
    # (the docstring above reads "chat request model"; it is kept verbatim
    # because pydantic uses it as the JSON-schema description).

    # The user's message; sent to the agent as a single HumanMessage.
    message: str
    # Client-supplied conversation id, "default" when omitted. The non-streaming
    # endpoint echoes it back in ChatResponse.
    # NOTE(review): in the visible code it is not used to keep per-conversation history.
    conversation_id: str = "default"
class ChatResponse(BaseModel):
    """聊天响应模型"""
    # Response body of the non-streaming /api/agent/chat endpoint (the
    # docstring above reads "chat response model"; it is kept verbatim because
    # pydantic uses it as the JSON-schema description).

    # The agent's reply text, or an error message prefixed with "错误: " when
    # the agent call failed.
    content: str
    # The request's conversation_id, echoed back unchanged.
    conversation_id: str
@app.get("/")
async def root():
    """根路径"""
    # GET / : returns a fixed service name and version payload.
    # (Docstring kept verbatim: FastAPI publishes it as the OpenAPI description.)
    # NOTE(review): "0.0.1" differs from version="1.1.0" passed to FastAPI(...);
    # confirm which one is intended.
    banner = dict(message="AI Agent", version="0.0.1")
    return banner
@app.post("/api/agent/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """非流式聊天接口"""
    # Non-streaming chat: run the agent once on request.message and return the
    # text of the last message in the resulting conversation.
    # (Docstring kept verbatim: FastAPI publishes it as the OpenAPI description.)
    # NOTE(review): conversation_id is only echoed back; no history is kept per id.
    logger = logging.getLogger(__name__)
    try:
        messages = [HumanMessage(content=request.message)]
        response = await agent.ainvoke({"messages": messages})
        # Debug level instead of print(): the full agent state contains the
        # user's message and was previously dumped to stdout on every request.
        logger.debug("结果: %s", response)
        # The final message's content may be a plain str or a list of content
        # blocks; normalise it the same way the streaming endpoint does so that
        # ChatResponse.content (a str field) always validates.
        content = _extract_text(response["messages"][-1].content)
        return ChatResponse(
            content=content,
            conversation_id=request.conversation_id,
        )
    except Exception as e:
        # Errors are still reported in-band (HTTP 200 with an error message) to
        # keep the existing response contract, but are now logged with traceback
        # instead of being swallowed silently.
        logger.exception("Agent invocation failed")
        return ChatResponse(
            content=f"错误: {str(e)}",
            conversation_id=request.conversation_id,
        )
- def _extract_text(content_block) -> str:
- """从LangChain内容块中提取纯文本"""
- if content_block is None:
- return ""
- if isinstance(content_block, str):
- return content_block
- # LangChain 可能返回 list[dict] 或 list[ContentBlock]
- text_parts = []
- if isinstance(content_block, list):
- for block in content_block:
- if isinstance(block, str):
- text_parts.append(block)
- elif isinstance(block, dict):
- if block.get("type") == "text":
- text_parts.append(block.get("text", ""))
- else:
- block_text = getattr(block, "text", None)
- if block_text:
- text_parts.append(block_text)
- else:
- block_text = getattr(content_block, "text", None)
- if block_text:
- text_parts.append(block_text)
- return "".join(text_parts)
async def generate_stream(request: ChatRequest) -> AsyncIterator[str]:
    """Stream the agent's reply to ``request.message`` as event payload strings.

    Yields one JSON string ``{"type": "text", "text": ...}`` for every chunk
    that carries non-empty text, then the terminator ``"[DONE]"``. If anything
    raises, a ``{"type": "error", "text": <message>}`` payload is yielded
    instead, still followed by ``"[DONE]"``.
    """
    try:
        agent_input = {"messages": [HumanMessage(content=request.message)]}
        # stream_mode="messages" produces (message_chunk, metadata) pairs; the
        # metadata is not needed here.
        async for chunk, _metadata in agent.astream(agent_input, stream_mode="messages"):
            # Prefer .content; fall back to .content_blocks when it has no text.
            piece = _extract_text(getattr(chunk, "content", None)) or _extract_text(
                getattr(chunk, "content_blocks", None)
            )
            if piece:
                yield json.dumps({"type": "text", "text": piece}, ensure_ascii=False)
        yield "[DONE]"
    except Exception as exc:
        # NOTE(review): the raw exception text is sent to the client; consider a
        # generic message if errors may contain sensitive details.
        yield json.dumps({"type": "error", "text": str(exc)}, ensure_ascii=False)
        yield "[DONE]"
@app.post("/api/agent/chat/stream", response_description='{"type": "text", "text": "你好"}')
async def chat_stream(request: ChatRequest):
    """流式聊天接口"""
    # Streaming chat endpoint: wraps the generate_stream() async generator in a
    # Server-Sent Events response. response_description shows the shape of a
    # text payload.
    # (Docstring kept verbatim: FastAPI publishes it as the OpenAPI description.)
    events = generate_stream(request)
    return EventSourceResponse(events)
if __name__ == "__main__":
    import uvicorn

    # Warn early when no API key is configured. The model uses QWEN_API_KEY,
    # or OPENAI_API_KEY when that is unset, so either one is enough. The old
    # check looked only at OPENAI_API_KEY and warned even when QWEN_API_KEY was set.
    if not (os.getenv("QWEN_API_KEY") or os.getenv("OPENAI_API_KEY")):
        print("警告: 未设置 QWEN_API_KEY 或 OPENAI_API_KEY 环境变量")
        print("请设置环境变量或创建 .env 文件")

    # Start the server on 127.0.0.1:8080.
    # The "main:app" import string makes uvicorn import this file again as
    # module `main`. That only works when the file is named main.py, and it
    # means the module-level setup runs a second time.
    # If starting this way still fails, use the command line instead:
    #   uvicorn main:app --host 127.0.0.1 --port 8080
    try:
        uvicorn.run(
            "main:app",  # import-string form, kept for compatibility
            host="127.0.0.1",
            port=8080,
            log_level="info",
            reload=False,  # reload disabled to avoid some compatibility issues
        )
    except Exception as e:
        print(f"启动错误: {e}")
        print("\n建议使用命令行启动:")
        # Same host as above, so following the suggestion does not silently
        # expose the server on all network interfaces.
        print("uvicorn main:app --host 127.0.0.1 --port 8080")
|