# request_logging.py — FastAPI request/response logging middleware.
import json
import logging
import os
import time
from datetime import datetime
from logging.handlers import RotatingFileHandler

from fastapi import FastAPI, Request
from starlette.concurrency import iterate_in_threadpool
  9. def setup_request_logging(app: FastAPI, max_response_len: int, max_bytes: int, backup_count: int) -> None:
  10. logger = logging.getLogger("cargo_height.request")
  11. _setup_request_logger(logger, max_bytes=max_bytes, backup_count=backup_count)
  12. @app.middleware("http")
  13. async def request_log_middleware(request: Request, call_next):
  14. request_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
  15. start = time.perf_counter()
  16. try:
  17. response = await call_next(request)
  18. except Exception:
  19. elapsed_ms = (time.perf_counter() - start) * 1000
  20. logger.exception(
  21. "request_time=%s method=%s path=%s status=%s duration_ms=%.2f response=%s",
  22. request_time,
  23. request.method,
  24. request.url.path,
  25. 500,
  26. elapsed_ms,
  27. "internal_error",
  28. )
  29. raise
  30. body = b""
  31. async for chunk in response.body_iterator:
  32. body += chunk
  33. response.body_iterator = iterate_in_threadpool(iter([body]))
  34. response_text = _parse_response_text(body, response.headers.get("content-type", ""))
  35. if len(response_text) > max_response_len:
  36. response_text = response_text[:max_response_len] + "...(truncated)"
  37. elapsed_ms = (time.perf_counter() - start) * 1000
  38. logger.info(
  39. "request_time=%s method=%s path=%s status=%s duration_ms=%.2f response=%s",
  40. request_time,
  41. request.method,
  42. request.url.path,
  43. response.status_code,
  44. elapsed_ms,
  45. response_text,
  46. )
  47. return response
  48. def _setup_request_logger(logger: logging.Logger, max_bytes: int, backup_count: int) -> None:
  49. log_dir = os.path.join(os.getcwd(), "Log")
  50. os.makedirs(log_dir, exist_ok=True)
  51. log_file = os.path.join(log_dir, "request.log")
  52. logger.setLevel(logging.INFO)
  53. logger.propagate = False
  54. if logger.handlers:
  55. return
  56. formatter = logging.Formatter(
  57. "%(asctime)s [%(levelname)s] %(name)s - %(message)s",
  58. "%Y-%m-%d %H:%M:%S",
  59. )
  60. file_handler = RotatingFileHandler(
  61. log_file,
  62. maxBytes=max_bytes,
  63. backupCount=backup_count,
  64. encoding="utf-8",
  65. )
  66. file_handler.setLevel(logging.INFO)
  67. file_handler.setFormatter(formatter)
  68. stream_handler = logging.StreamHandler()
  69. stream_handler.setLevel(logging.INFO)
  70. stream_handler.setFormatter(formatter)
  71. logger.addHandler(file_handler)
  72. logger.addHandler(stream_handler)
  73. def _parse_response_text(body: bytes, content_type: str) -> str:
  74. if not body:
  75. return ""
  76. if "application/json" in content_type:
  77. try:
  78. return json.dumps(json.loads(body), ensure_ascii=False)
  79. except Exception:
  80. return body.decode("utf-8", errors="replace")
  81. return body.decode("utf-8", errors="replace")