api.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. import os
  2. import time
  3. import threading
  4. import json
  5. import logging
  6. from datetime import datetime
  7. from logging.handlers import RotatingFileHandler
  8. import cv2
  9. import numpy as np
  10. from fastapi import FastAPI, HTTPException, Request
  11. from starlette.concurrency import iterate_in_threadpool
  12. import uvicorn
  13. from depth_common import (
  14. Settings,
  15. TemporalFilter,
  16. compute_roi_bounds,
  17. extract_depth_data,
  18. find_nearest_point,
  19. init_depth_pipeline,
  20. nearest_distance_in_roi,
  21. )
  22. from utils import frame_to_bgr_image
  # --- Sampling parameters used by the /height endpoint ---
  SAMPLE_COUNT = 10  # valid samples needed before a height is reported
  FRAME_TIMEOUT_MS = 200  # timeout passed to each wait_for_frames() call
  SAMPLE_TIMEOUT_SEC = 8  # total time budget for collecting SAMPLE_COUNT samples
  MAX_SAVED_IMAGES = int(os.getenv("MAX_SAVED_IMAGES", "1000"))  # PNG files kept in sample_images/

  # Measurement configuration, read from environment variables.
  SETTINGS = Settings.from_env()

  app = FastAPI(title="Cargo Height API")

  # --- Request logging ---
  request_logger = logging.getLogger("cargo_height.request")
  MAX_LOG_RESPONSE_LEN = 1000  # characters of response text kept per log line
  _REQUEST_LOG_DEFAULT_MAX_BYTES = 20 * 1024 * 1024  # 20 MiB per log file before rotation
  REQUEST_LOG_MAX_BYTES = int(
      os.getenv("REQUEST_LOG_MAX_BYTES", str(_REQUEST_LOG_DEFAULT_MAX_BYTES))
  )
  REQUEST_LOG_BACKUP_COUNT = int(os.getenv("REQUEST_LOG_BACKUP_COUNT", "10"))
  def _setup_request_logger():
      """Attach a rotating file handler and a stream handler to ``request_logger``.

      The log file is ``<cwd>/Log/request.log``. Its size limit and backup count
      come from REQUEST_LOG_MAX_BYTES / REQUEST_LOG_BACKUP_COUNT. Both handlers use
      INFO level and one shared formatter. Records are not propagated to ancestor
      loggers. The directory is always created and the level/propagation always
      set, but handlers are only added when the logger has none yet, so repeated
      calls do not duplicate output.
      """
      log_dir = os.path.join(os.getcwd(), "Log")
      os.makedirs(log_dir, exist_ok=True)
      log_path = os.path.join(log_dir, "request.log")

      request_logger.setLevel(logging.INFO)
      request_logger.propagate = False
      if request_logger.handlers:
          return

      shared_formatter = logging.Formatter(
          "%(asctime)s [%(levelname)s] %(name)s - %(message)s",
          "%Y-%m-%d %H:%M:%S",
      )
      # File handler first, then console: same creation and attachment order.
      handlers = (
          RotatingFileHandler(
              log_path,
              maxBytes=REQUEST_LOG_MAX_BYTES,
              backupCount=REQUEST_LOG_BACKUP_COUNT,
              encoding="utf-8",
          ),
          logging.StreamHandler(),
      )
      for handler in handlers:
          handler.setLevel(logging.INFO)
          handler.setFormatter(shared_formatter)
          request_logger.addHandler(handler)


  _setup_request_logger()
  # Layout of the per-request log line, shared by the success and failure paths.
  _REQUEST_LOG_FORMAT = (
      "request_time=%s method=%s path=%s status=%s duration_ms=%.2f response=%s"
  )


  def _format_response_text(body, content_type):
      """Render a buffered response body as text for the request log.

      JSON bodies are re-serialized with ``ensure_ascii=False`` so non-ASCII
      characters appear literally instead of as escape sequences. Any other body,
      or a JSON body that fails to parse, is decoded as UTF-8 with replacement
      characters. The result is capped at MAX_LOG_RESPONSE_LEN characters
      followed by a ``...(truncated)`` marker.

      Args:
          body: the complete response body as bytes.
          content_type: value of the response's Content-Type header ("" if absent).

      Returns:
          The text to log ("" for an empty body).
      """
      if not body:
          return ""
      text = None
      if "application/json" in content_type:
          try:
              text = json.dumps(json.loads(body), ensure_ascii=False)
          except (ValueError, RecursionError):
              # Declared as JSON but not parseable: fall back to raw text below.
              text = None
      if text is None:
          text = body.decode("utf-8", errors="replace")
      if len(text) > MAX_LOG_RESPONSE_LEN:
          text = text[:MAX_LOG_RESPONSE_LEN] + "...(truncated)"
      return text


  @app.middleware("http")
  async def request_log_middleware(request: Request, call_next):
      """Log one line per HTTP request: start time, method, path, status, latency, body.

      The whole response body is buffered in memory so it can be logged, then
      handed back to the response as a single chunk. If the downstream handler
      raises, the request is logged with status 500 (and the traceback) and the
      exception is re-raised.
      """
      request_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
      start = time.perf_counter()
      try:
          response = await call_next(request)
      except Exception:
          elapsed_ms = (time.perf_counter() - start) * 1000
          request_logger.exception(
              _REQUEST_LOG_FORMAT,
              request_time,
              request.method,
              request.url.path,
              500,
              elapsed_ms,
              "internal_error",
          )
          raise
      # Join the chunks once instead of `bytes +=` per chunk (quadratic copying).
      chunks = [chunk async for chunk in response.body_iterator]
      body = b"".join(chunks)
      # The original iterator is exhausted; replay the buffered body to the client.
      response.body_iterator = iterate_in_threadpool(iter([body]))
      response_text = _format_response_text(
          body, response.headers.get("content-type", "")
      )
      elapsed_ms = (time.perf_counter() - start) * 1000
      request_logger.info(
          _REQUEST_LOG_FORMAT,
          request_time,
          request.method,
          request.url.path,
          response.status_code,
          elapsed_ms,
          response_text,
      )
      return response
  # Camera state shared across requests. The /height handler serializes its use
  # of the camera by holding _lock.
  _lock = threading.Lock()
  _pipeline = None  # set by _init_camera(), cleared by _shutdown_camera()
  _depth_intrinsics = None  # second value returned by init_depth_pipeline()
  _temporal_filter = None  # TemporalFilter created by _init_camera()
  def _init_camera():
      """Start the depth pipeline and create the temporal filter, once.

      Does nothing when ``_pipeline`` is already set, so repeated calls do not
      start the camera twice. On success it sets ``_pipeline``,
      ``_depth_intrinsics`` and a fresh ``TemporalFilter(alpha=0.5)``.

      Raises:
          RuntimeError: if ``init_depth_pipeline()`` (or unpacking its result)
              fails. The underlying exception is chained as ``__cause__``.
      """
      global _pipeline, _depth_intrinsics, _temporal_filter
      if _pipeline is None:
          try:
              started_pipeline, intrinsics, _unused = init_depth_pipeline()
          except Exception as exc:
              raise RuntimeError(f"Failed to init depth camera: {exc}") from exc
          _pipeline = started_pipeline
          _depth_intrinsics = intrinsics
          _temporal_filter = TemporalFilter(alpha=0.5)
  def _shutdown_camera():
      """Stop the depth pipeline if one is running and clear ``_pipeline``.

      ``_pipeline`` is reset even when ``stop()`` raises. Otherwise the module
      would keep a reference to a half-stopped pipeline, and ``_init_camera()``,
      which skips initialization while ``_pipeline`` is set, could never start a
      fresh one. Any exception from ``stop()`` still propagates to the caller.
      """
      global _pipeline
      if _pipeline is None:
          return
      try:
          _pipeline.stop()
      finally:
          _pipeline = None
  def _measure_once():
      """Grab one frameset and measure the nearest distance inside the ROI.

      Waits up to FRAME_TIMEOUT_MS for a frameset from ``_pipeline``. The depth
      frame goes through ``extract_depth_data`` (which uses ``_temporal_filter``),
      the ROI is computed with ``compute_roi_bounds``, and the nearest distance
      within it comes from ``nearest_distance_in_roi``. get_height calls this
      while holding ``_lock``.

      Returns:
          ``None`` when no frameset arrives, the frameset has no depth frame, or
          any processing step returns ``None``. Otherwise a dict with:
          ``nearest_distance``; ``color_frame`` (may be ``None``); ``depth_data``;
          ``bounds`` as ``(x_start, x_end, y_start, y_end, center_distance)``;
          and ``center_distance``.
      """
      frames = _pipeline.wait_for_frames(FRAME_TIMEOUT_MS)
      if frames is None:
          return None
      color_frame = frames.get_color_frame()
      depth_frame = frames.get_depth_frame()
      if depth_frame is None:
          # A frameset may lack a depth component (get_height already handles a
          # missing color frame the same way); skip it instead of handing None to
          # extract_depth_data.
          return None
      depth_data = extract_depth_data(depth_frame, SETTINGS, _temporal_filter)
      if depth_data is None:
          return None
      bounds = compute_roi_bounds(depth_data, _depth_intrinsics, SETTINGS)
      if bounds is None:
          return None
      x_start, x_end, y_start, y_end, center_distance = bounds
      roi = depth_data[y_start:y_end, x_start:x_end]
      nearest_distance = nearest_distance_in_roi(roi, SETTINGS)
      if nearest_distance is None:
          return None
      return {
          "nearest_distance": nearest_distance,
          "color_frame": color_frame,
          "depth_data": depth_data,
          "bounds": bounds,
          "center_distance": center_distance,
      }
  def _save_current_sample_images(sample):
      """Save debug PNGs for one measurement sample, then prune old images.

      Files are written to ``<cwd>/sample_images`` and share one millisecond
      timestamp (``YYYYmmdd_HHMMSS_mmm``):

      * ``color_<W>x<H>_<ts>.png``: the sample's color frame converted with
        ``frame_to_bgr_image``. Skipped when there is no color frame or the
        conversion returns ``None``.
      * ``depth_annotated_<W>x<H>_<ts>.png``: ``depth_data`` min-max normalized
        to 8 bits and false-colored with ``COLORMAP_JET``. It is annotated with
        the ROI rectangle, the nearest point (when ``find_nearest_point`` finds
        one), and the nearest/center distance labels.

      Afterwards the directory is pruned to MAX_SAVED_IMAGES PNG files. A failed
      ``cv2.imwrite`` is logged as a warning instead of being ignored.

      Args:
          sample: dict produced by ``_measure_once`` (``depth_data``, ``bounds``,
              ``nearest_distance`` and optionally ``color_frame``).
      """
      save_image_dir = os.path.join(os.getcwd(), "sample_images")
      os.makedirs(save_image_dir, exist_ok=True)
      # Read the clock once. Taking seconds and milliseconds from two separate
      # reads can straddle a second boundary and yield a timestamp ~1 s early.
      timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]

      color_frame = sample.get("color_frame")
      if color_frame is not None:
          color_image = frame_to_bgr_image(color_frame)
          if color_image is not None:
              color_height, color_width = color_image.shape[:2]
              color_file = os.path.join(
                  save_image_dir,
                  f"color_{color_width}x{color_height}_{timestamp}.png",
              )
              if not cv2.imwrite(color_file, color_image):
                  request_logger.warning("Failed to write sample image: %s", color_file)

      depth_data = sample["depth_data"]
      x_start, x_end, y_start, y_end, center_distance = sample["bounds"]
      nearest_distance = sample["nearest_distance"]
      roi = depth_data[y_start:y_end, x_start:x_end]
      nearest_point = find_nearest_point(roi, x_start, y_start, SETTINGS, nearest_distance)

      depth_image = cv2.normalize(depth_data, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
      depth_image = cv2.applyColorMap(depth_image, cv2.COLORMAP_JET)
      # ROI outline; the end coordinates are exclusive slice bounds, hence the -1.
      cv2.rectangle(
          depth_image,
          (x_start, y_start),
          (x_end - 1, y_end - 1),
          (0, 255, 0),
          2,
      )
      if nearest_point is not None:
          cv2.circle(depth_image, nearest_point, 4, (0, 0, 0), -1)
          cv2.circle(depth_image, nearest_point, 6, (0, 255, 255), 2)
      cv2.putText(
          depth_image,
          f"nearest: {nearest_distance} mm",
          (10, 30),
          cv2.FONT_HERSHEY_SIMPLEX,
          0.8,
          (255, 255, 255),
          2,
          cv2.LINE_AA,
      )
      cv2.putText(
          depth_image,
          f"center: {int(center_distance)} mm",
          (10, 60),
          cv2.FONT_HERSHEY_SIMPLEX,
          0.8,
          (255, 255, 255),
          2,
          cv2.LINE_AA,
      )
      depth_h, depth_w = depth_image.shape[:2]
      depth_file = os.path.join(
          save_image_dir,
          f"depth_annotated_{depth_w}x{depth_h}_{timestamp}.png",
      )
      if not cv2.imwrite(depth_file, depth_image):
          request_logger.warning("Failed to write sample image: %s", depth_file)
      _prune_saved_images(save_image_dir, MAX_SAVED_IMAGES)
  def _prune_saved_images(save_dir, max_images):
      """Delete the oldest PNG files in ``save_dir`` so at most ``max_images`` remain.

      Only regular files whose name ends in ``.png`` (case-insensitive) are
      counted; directories and other files are ignored. Files are ordered by
      modification time, with ties broken by path, and the oldest surplus ones
      are removed. A ``max_images`` of 0 or less removes every counted file.

      This is best-effort cleanup. A file that disappears or cannot be stat'ed
      is skipped rather than raising, and a file that cannot be removed is
      logged and left in place.
      """
      candidates = []
      with os.scandir(save_dir) as entries:
          for entry in entries:
              if not entry.name.lower().endswith(".png"):
                  continue
              try:
                  if not entry.is_file():
                      continue
                  mtime = entry.stat().st_mtime
              except OSError:
                  # Removed or became unreadable after listing; nothing to prune.
                  continue
              candidates.append((mtime, entry.path))
      excess = len(candidates) - max_images
      if excess <= 0:
          return
      candidates.sort()
      for _mtime, file_path in candidates[:excess]:
          try:
              os.remove(file_path)
          except OSError as exc:
              request_logger.warning("Failed to remove old sample image %s: %s", file_path, exc)
  @app.on_event("startup")
  def on_startup():
      """Server startup hook: bring up the depth camera via ``_init_camera()``."""
      return _init_camera()
  @app.on_event("shutdown")
  def on_shutdown():
      """Server shutdown hook: release the depth camera via ``_shutdown_camera()``."""
      return _shutdown_camera()
  def _pull_color_frame(max_attempts=5):
      """Pull up to ``max_attempts`` extra framesets and return the first color frame.

      Used when no color frame arrived alongside the valid depth samples. Each
      pull waits up to FRAME_TIMEOUT_MS. Returns ``None`` if every pull times out
      or lacks a color frame. get_height calls it while holding ``_lock``.
      """
      for _ in range(max_attempts):
          frames = _pipeline.wait_for_frames(FRAME_TIMEOUT_MS)
          if frames is None:
              continue
          color_frame = frames.get_color_frame()
          if color_frame is not None:
              return color_frame
      return None


  @app.get("/height")
  def get_height():
      """Collect SAMPLE_COUNT valid samples and return their median distance.

      Sampling runs under ``_lock`` until SAMPLE_COUNT valid samples are collected
      or SAMPLE_TIMEOUT_SEC elapses. If at least one sample is valid, debug images
      for the first valid sample are saved on a best-effort basis (a failure is
      logged, not raised).

      Returns:
          dict with ``height_mm`` (``int`` of the numpy median; for an even count
          numpy averages the two middle values and ``int()`` truncates),
          ``samples``, ``unit`` ("mm") and ``sample_count``.

      Raises:
          HTTPException: 503 if fewer than SAMPLE_COUNT valid samples were
              collected within the time budget.
      """
      # Monotonic clock so the sampling budget is immune to wall-clock jumps.
      start_time = time.monotonic()
      samples = []
      first_valid_sample = None
      first_color_frame = None
      with _lock:
          while len(samples) < SAMPLE_COUNT and (time.monotonic() - start_time) < SAMPLE_TIMEOUT_SEC:
              sample = _measure_once()
              if sample is None:
                  continue
              samples.append(sample["nearest_distance"])
              if first_valid_sample is None:
                  first_valid_sample = sample
              if first_color_frame is None and sample.get("color_frame") is not None:
                  first_color_frame = sample["color_frame"]
          if first_valid_sample is not None:
              # Extra pulls only matter when there is a sample to save images for.
              if first_color_frame is None:
                  first_color_frame = _pull_color_frame()
              if first_color_frame is not None:
                  first_valid_sample["color_frame"] = first_color_frame
              # Debug images are a side output; failing to write them must not turn
              # a valid measurement (or a 503) into a 500.
              try:
                  _save_current_sample_images(first_valid_sample)
              except Exception:
                  request_logger.exception("Failed to save sample images")
      if len(samples) < SAMPLE_COUNT:
          raise HTTPException(status_code=503, detail="Insufficient valid samples from depth camera")
      median_value = int(np.median(np.array(samples, dtype=np.int32)))
      return {
          "height_mm": median_value,
          "samples": samples,
          "unit": "mm",
          "sample_count": SAMPLE_COUNT,
      }
  @app.get("/health")
  def health():
      """Health-check endpoint; always reports ``{"status": "ok"}`` without touching the camera."""
      status = {"status": "ok"}
      return status
  def main():
      """Serve this module's FastAPI ``app`` with uvicorn.

      The listen address comes from API_HOST (default ``127.0.0.1``) and
      API_PORT (default ``8080``). Raises ``ValueError`` if API_PORT is not an
      integer.
      """
      host = os.getenv("API_HOST", "127.0.0.1")
      port = int(os.getenv("API_PORT", "8080"))
      # Pass the app object, not the "api:app" import string. With the string,
      # running this file as a script makes uvicorn import it a second time as
      # module ``api``, re-running all module-level setup (settings, app, logger),
      # and it breaks if the file is renamed. The import string is only required
      # for reload/workers, which are not used here.
      uvicorn.run(app, host=host, port=port, log_level="info")


  if __name__ == "__main__":
      main()