request_logging.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. """请求日志中间件。
  2. 记录 FastAPI 每个请求的路径、状态码、耗时和响应摘要,支持日志轮转。
  3. """
  4. import json
  5. import logging
  6. import os
  7. import time
  8. from datetime import datetime
  9. from logging.handlers import RotatingFileHandler
  10. from fastapi import FastAPI, Request
  11. from starlette.concurrency import iterate_in_threadpool
  12. def setup_request_logging(app: FastAPI, max_response_len: int, max_bytes: int, backup_count: int) -> None:
  13. """为应用安装请求日志中间件。"""
  14. logger = logging.getLogger("cargo_height.request")
  15. _setup_request_logger(logger, max_bytes=max_bytes, backup_count=backup_count)
  16. @app.middleware("http")
  17. async def request_log_middleware(request: Request, call_next):
  18. # 记录请求开始时间,用于后续耗时统计。
  19. request_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
  20. start = time.perf_counter()
  21. try:
  22. response = await call_next(request)
  23. except Exception:
  24. # 异常请求按 500 记录并保留堆栈。
  25. elapsed_ms = (time.perf_counter() - start) * 1000
  26. logger.exception(
  27. "request_time=%s method=%s path=%s status=%s duration_ms=%.2f response=%s",
  28. request_time,
  29. request.method,
  30. request.url.path,
  31. 500,
  32. elapsed_ms,
  33. "internal_error",
  34. )
  35. raise
  36. # 读取响应体用于日志输出,然后恢复迭代器,避免影响客户端接收数据。
  37. body = b""
  38. async for chunk in response.body_iterator:
  39. body += chunk
  40. response.body_iterator = iterate_in_threadpool(iter([body]))
  41. response_text = _parse_response_text(body, response.headers.get("content-type", ""))
  42. if len(response_text) > max_response_len:
  43. response_text = response_text[:max_response_len] + "...(truncated)"
  44. elapsed_ms = (time.perf_counter() - start) * 1000
  45. logger.info(
  46. "request_time=%s method=%s path=%s status=%s duration_ms=%.2f response=%s",
  47. request_time,
  48. request.method,
  49. request.url.path,
  50. response.status_code,
  51. elapsed_ms,
  52. response_text,
  53. )
  54. return response
  55. def _setup_request_logger(logger: logging.Logger, max_bytes: int, backup_count: int) -> None:
  56. """初始化请求日志记录器(文件 + 控制台)。"""
  57. log_dir = os.path.join(os.getcwd(), "Log")
  58. os.makedirs(log_dir, exist_ok=True)
  59. log_file = os.path.join(log_dir, "request.log")
  60. logger.setLevel(logging.INFO)
  61. logger.propagate = False
  62. if logger.handlers:
  63. return
  64. formatter = logging.Formatter(
  65. "%(asctime)s [%(levelname)s] %(name)s - %(message)s",
  66. "%Y-%m-%d %H:%M:%S",
  67. )
  68. file_handler = RotatingFileHandler(
  69. log_file,
  70. maxBytes=max_bytes,
  71. backupCount=backup_count,
  72. encoding="utf-8",
  73. )
  74. file_handler.setLevel(logging.INFO)
  75. file_handler.setFormatter(formatter)
  76. stream_handler = logging.StreamHandler()
  77. stream_handler.setLevel(logging.INFO)
  78. stream_handler.setFormatter(formatter)
  79. logger.addHandler(file_handler)
  80. logger.addHandler(stream_handler)
  81. def _parse_response_text(body: bytes, content_type: str) -> str:
  82. """根据响应类型提取可读日志文本。"""
  83. if not body:
  84. return ""
  85. if "application/json" in content_type:
  86. try:
  87. # 统一 JSON 格式,便于检索。
  88. return json.dumps(json.loads(body), ensure_ascii=False)
  89. except Exception:
  90. return body.decode("utf-8", errors="replace")
  91. return body.decode("utf-8", errors="replace")