fengyanglei 1 kuukausi sitten
vanhempi
commit
e6616aa332
4 muutettua tiedostoa jossa 351 lisäystä ja 311 poistoa
  1. 21 311
      api.py
  2. 33 0
      api_config.py
  3. 200 0
      cargo_service.py
  4. 97 0
      request_logging.py

+ 21 - 311
api.py

@@ -1,337 +1,47 @@
-import os
-import time
-import threading
-import json
-import logging
-from datetime import datetime
-from logging.handlers import RotatingFileHandler
-
-import cv2
-import numpy as np
-from fastapi import FastAPI, HTTPException, Request
-from starlette.concurrency import iterate_in_threadpool
+from fastapi import FastAPI, HTTPException
 import uvicorn
 
-from depth_common import (
-    Settings,
-    TemporalFilter,
-    compute_roi_bounds,
-    extract_depth_data,
-    find_nearest_point,
-    init_depth_pipeline,
-    nearest_distance_in_roi,
-)
-from utils import frame_to_bgr_image
+from api_config import ApiConfig
+from cargo_service import CargoHeightService
+from request_logging import setup_request_logging
 
-# 采样参数
-SAMPLE_COUNT = 10
-FRAME_TIMEOUT_MS = 200
-SAMPLE_TIMEOUT_SEC = 8
-MAX_SAVED_IMAGES = int(os.getenv("MAX_SAVED_IMAGES", "1000"))
-# 从环境变量加载测量配置
-SETTINGS = Settings.from_env()
# Process-wide singletons: configuration read once from the environment and
# the camera-backed measurement service shared by all requests.
config = ApiConfig.from_env()
service = CargoHeightService(config)

app = FastAPI(title="Cargo Height API")
-request_logger = logging.getLogger("cargo_height.request")
-MAX_LOG_RESPONSE_LEN = 1000
-REQUEST_LOG_MAX_BYTES = int(os.getenv("REQUEST_LOG_MAX_BYTES", str(20 * 1024 * 1024)))
-REQUEST_LOG_BACKUP_COUNT = int(os.getenv("REQUEST_LOG_BACKUP_COUNT", "10"))
-
-
-def _setup_request_logger():
-    log_dir = os.path.join(os.getcwd(), "Log")
-    os.makedirs(log_dir, exist_ok=True)
-    log_file = os.path.join(log_dir, "request.log")
-
-    request_logger.setLevel(logging.INFO)
-    request_logger.propagate = False
-
-    if request_logger.handlers:
-        return
-
-    formatter = logging.Formatter(
-        "%(asctime)s [%(levelname)s] %(name)s - %(message)s",
-        "%Y-%m-%d %H:%M:%S",
-    )
-
-    file_handler = RotatingFileHandler(
-        log_file,
-        maxBytes=REQUEST_LOG_MAX_BYTES,
-        backupCount=REQUEST_LOG_BACKUP_COUNT,
-        encoding="utf-8",
-    )
-    file_handler.setLevel(logging.INFO)
-    file_handler.setFormatter(formatter)
-
-    stream_handler = logging.StreamHandler()
-    stream_handler.setLevel(logging.INFO)
-    stream_handler.setFormatter(formatter)
-
-    request_logger.addHandler(file_handler)
-    request_logger.addHandler(stream_handler)
-
-
-_setup_request_logger()
-
-
-@app.middleware("http")
-async def request_log_middleware(request: Request, call_next):
-    request_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
-    start = time.perf_counter()
-    try:
-        response = await call_next(request)
-    except Exception:
-        elapsed_ms = (time.perf_counter() - start) * 1000
-        request_logger.exception(
-            "request_time=%s method=%s path=%s status=%s duration_ms=%.2f response=%s",
-            request_time,
-            request.method,
-            request.url.path,
-            500,
-            elapsed_ms,
-            "internal_error",
-        )
-        raise
-
-    body = b""
-    async for chunk in response.body_iterator:
-        body += chunk
-    response.body_iterator = iterate_in_threadpool(iter([body]))
-
-    if not body:
-        response_text = ""
-    else:
-        content_type = response.headers.get("content-type", "")
-        if "application/json" in content_type:
-            try:
-                response_text = json.dumps(json.loads(body), ensure_ascii=False)
-            except Exception:
-                response_text = body.decode("utf-8", errors="replace")
-        else:
-            response_text = body.decode("utf-8", errors="replace")
-    if len(response_text) > MAX_LOG_RESPONSE_LEN:
-        response_text = response_text[:MAX_LOG_RESPONSE_LEN] + "...(truncated)"
-
-    elapsed_ms = (time.perf_counter() - start) * 1000
-    request_logger.info(
-        "request_time=%s method=%s path=%s status=%s duration_ms=%.2f response=%s",
-        request_time,
-        request.method,
-        request.url.path,
-        response.status_code,
-        elapsed_ms,
-        response_text,
-    )
-    return response
-
-
-# 相机相关的全局状态(由锁保护)
-_pipeline = None
-_depth_intrinsics = None
-_temporal_filter = None
-_lock = threading.Lock()
-
-
-def _init_camera():
-    # 延迟初始化相机,避免重复启动
-    global _pipeline, _depth_intrinsics, _temporal_filter
-    if _pipeline is not None:
-        return
-    try:
-        pipeline, depth_intrinsics, _ = init_depth_pipeline()
-    except Exception as exc:
-        raise RuntimeError(f"Failed to init depth camera: {exc}") from exc
-    _pipeline = pipeline
-    _depth_intrinsics = depth_intrinsics
-    _temporal_filter = TemporalFilter(alpha=0.5)
-
-
-def _shutdown_camera():
-    # 关闭相机资源
-    global _pipeline
-    if _pipeline is None:
-        return
-    _pipeline.stop()
-    _pipeline = None
-
-
-def _measure_once():
-    # 单次采样:获取一帧并在 ROI 内计算最近距离
-    frames = _pipeline.wait_for_frames(FRAME_TIMEOUT_MS)
-    if frames is None:
-        return None
-    color_frame = frames.get_color_frame()
-    depth_frame = frames.get_depth_frame()
-    depth_data = extract_depth_data(depth_frame, SETTINGS, _temporal_filter)
-    if depth_data is None:
-        return None
-    bounds = compute_roi_bounds(depth_data, _depth_intrinsics, SETTINGS)
-    if bounds is None:
-        return None
-    x_start, x_end, y_start, y_end, center_distance = bounds
-    roi = depth_data[y_start:y_end, x_start:x_end]
-    nearest_distance = nearest_distance_in_roi(roi, SETTINGS)
-    if nearest_distance is None:
-        return None
-    return {
-        "nearest_distance": nearest_distance,
-        "color_frame": color_frame,
-        "depth_data": depth_data,
-        "bounds": bounds,
-        "center_distance": center_distance,
-    }
-
-
-def _save_current_sample_images(sample):
-    save_image_dir = os.path.join(os.getcwd(), "sample_images")
-    os.makedirs(save_image_dir, exist_ok=True)
-    now = time.localtime()
-    time_str = time.strftime("%Y%m%d_%H%M%S", now)
-    millis = int((time.time() % 1) * 1000)
-    timestamp = f"{time_str}_{millis:03d}"
-
-    color_frame = sample.get("color_frame")
-    if color_frame is not None:
-        color_image = frame_to_bgr_image(color_frame)
-        if color_image is not None:
-            color_height, color_width = color_image.shape[:2]
-            color_file = os.path.join(
-                save_image_dir,
-                f"color_{color_width}x{color_height}_{timestamp}.png",
-            )
-            cv2.imwrite(color_file, color_image)
-
-    depth_data = sample["depth_data"]
-    x_start, x_end, y_start, y_end, center_distance = sample["bounds"]
-    nearest_distance = sample["nearest_distance"]
-    roi = depth_data[y_start:y_end, x_start:x_end]
-    nearest_point = find_nearest_point(roi, x_start, y_start, SETTINGS, nearest_distance)
-
-    depth_image = cv2.normalize(depth_data, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
-    depth_image = cv2.applyColorMap(depth_image, cv2.COLORMAP_JET)
-    cv2.rectangle(
-        depth_image,
-        (x_start, y_start),
-        (x_end - 1, y_end - 1),
-        (0, 255, 0),
-        2,
-    )
-    if nearest_point is not None:
-        cv2.circle(depth_image, nearest_point, 4, (0, 0, 0), -1)
-        cv2.circle(depth_image, nearest_point, 6, (0, 255, 255), 2)
-
-    cv2.putText(
-        depth_image,
-        f"nearest: {nearest_distance} mm",
-        (10, 30),
-        cv2.FONT_HERSHEY_SIMPLEX,
-        0.8,
-        (255, 255, 255),
-        2,
-        cv2.LINE_AA,
-    )
-    cv2.putText(
-        depth_image,
-        f"center: {int(center_distance)} mm",
-        (10, 60),
-        cv2.FONT_HERSHEY_SIMPLEX,
-        0.8,
-        (255, 255, 255),
-        2,
-        cv2.LINE_AA,
-    )
-
-    depth_h, depth_w = depth_image.shape[:2]
-    depth_file = os.path.join(
-        save_image_dir,
-        f"depth_annotated_{depth_w}x{depth_h}_{timestamp}.png",
-    )
-    cv2.imwrite(depth_file, depth_image)
-    _prune_saved_images(save_image_dir, MAX_SAVED_IMAGES)
-
-
-def _prune_saved_images(save_dir, max_images):
-    png_files = [
-        os.path.join(save_dir, name)
-        for name in os.listdir(save_dir)
-        if name.lower().endswith(".png")
-    ]
-    if len(png_files) <= max_images:
-        return
-    png_files.sort(key=os.path.getmtime)
-    for file_path in png_files[: len(png_files) - max_images]:
-        try:
-            os.remove(file_path)
-        except OSError:
-            pass
# Install the request/response logging middleware (rotating file in ./Log).
setup_request_logging(
    app,
    max_response_len=config.request_log_max_len,
    max_bytes=config.request_log_max_bytes,
    backup_count=config.request_log_backup_count,
)
 
 
# NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
# lifespan handlers — consider migrating when the FastAPI version allows.
@app.on_event("startup")
def on_startup() -> None:
    """Initialize the depth camera when the API process starts."""
    service.startup()
 
 
@app.on_event("shutdown")
def on_shutdown() -> None:
    """Release the depth camera when the API process stops."""
    service.shutdown()
 
 
 @app.get("/height")
 def get_height():
-    # 采集多次样本并返回中位数高度
-    start_time = time.time()
-    samples = []
-    first_valid_sample = None
-    first_color_frame = None
-    with _lock:
-        while len(samples) < SAMPLE_COUNT and (time.time() - start_time) < SAMPLE_TIMEOUT_SEC:
-            sample = _measure_once()
-            if sample is not None:
-                samples.append(sample["nearest_distance"])
-                if first_valid_sample is None:
-                    first_valid_sample = sample
-                if first_color_frame is None and sample.get("color_frame") is not None:
-                    first_color_frame = sample.get("color_frame")
-
-        # If no color frame arrived during valid depth sampling, try a few extra pulls.
-        if first_color_frame is None:
-            for _ in range(5):
-                frames = _pipeline.wait_for_frames(FRAME_TIMEOUT_MS)
-                if frames is None:
-                    continue
-                color_frame = frames.get_color_frame()
-                if color_frame is not None:
-                    first_color_frame = color_frame
-                    break
-
-        if first_valid_sample is not None:
-            if first_color_frame is not None:
-                first_valid_sample["color_frame"] = first_color_frame
-            _save_current_sample_images(first_valid_sample)
-    if len(samples) < SAMPLE_COUNT:
+    result = service.measure_height()
+    if result is None:
         raise HTTPException(status_code=503, detail="Insufficient valid samples from depth camera")
-    median_value = int(np.median(np.array(samples, dtype=np.int32)))
-    return {
-        "height_mm": median_value,
-        "samples": samples,
-        "unit": "mm",
-        "sample_count": SAMPLE_COUNT,
-    }
+    return result
 
 
@app.get("/health")
def health():
    """Liveness probe; always reports ok without touching the camera."""
    return {"status": "ok"}
 
 
def main() -> None:
    """Run the API under uvicorn using host/port from the environment config."""
    uvicorn.run("api:app", host=config.api_host, port=config.api_port, log_level="info")
 
 
 if __name__ == "__main__":

+ 33 - 0
api_config.py

@@ -0,0 +1,33 @@
+import os
+from dataclasses import dataclass
+
+from depth_common import Settings
+
+
@dataclass(frozen=True)
class ApiConfig:
    """Immutable runtime configuration for the cargo-height API.

    Every field is sourced from an environment variable by ``from_env``,
    falling back to a built-in default when the variable is unset.
    """

    sample_count: int
    frame_timeout_ms: int
    sample_timeout_sec: int
    max_saved_images: int
    request_log_max_len: int
    request_log_max_bytes: int
    request_log_backup_count: int
    api_host: str
    api_port: int
    settings: Settings

    @classmethod
    def from_env(cls) -> "ApiConfig":
        """Build an ApiConfig from environment variables with defaults."""

        def env_int(name: str, default: int) -> int:
            # Integer-valued environment variable with a numeric fallback.
            return int(os.getenv(name, str(default)))

        return cls(
            sample_count=env_int("SAMPLE_COUNT", 10),
            frame_timeout_ms=env_int("FRAME_TIMEOUT_MS", 200),
            sample_timeout_sec=env_int("SAMPLE_TIMEOUT_SEC", 8),
            max_saved_images=env_int("MAX_SAVED_IMAGES", 1000),
            request_log_max_len=env_int("REQUEST_LOG_MAX_LEN", 1000),
            request_log_max_bytes=env_int("REQUEST_LOG_MAX_BYTES", 20 * 1024 * 1024),
            request_log_backup_count=env_int("REQUEST_LOG_BACKUP_COUNT", 10),
            api_host=os.getenv("API_HOST", "127.0.0.1"),
            api_port=env_int("API_PORT", 8080),
            settings=Settings.from_env(),
        )

+ 200 - 0
cargo_service.py

@@ -0,0 +1,200 @@
+import os
+import threading
+import time
+from typing import Any, Dict, Optional
+
+import cv2
+import numpy as np
+
+from api_config import ApiConfig
+from depth_common import (
+    TemporalFilter,
+    compute_roi_bounds,
+    extract_depth_data,
+    find_nearest_point,
+    init_depth_pipeline,
+    nearest_distance_in_roi,
+)
+from utils import frame_to_bgr_image
+
+
class CargoHeightService:
    """Owns the depth camera and turns multi-sample depth readings into a
    single cargo-height measurement.

    All pipeline access — including startup and shutdown — is serialized
    through ``self._lock`` so concurrent HTTP requests cannot interleave
    frame reads or stop the camera mid-measurement.
    """

    def __init__(self, config: "ApiConfig") -> None:
        self.config = config
        self._pipeline = None          # camera pipeline, created lazily by startup()
        self._depth_intrinsics = None  # intrinsics captured when the pipeline starts
        self._temporal_filter = None   # temporal smoothing across depth frames
        self._lock = threading.Lock()  # guards every pipeline access

    def startup(self) -> None:
        """Initialize the depth camera pipeline (idempotent).

        Raises:
            RuntimeError: if the camera pipeline cannot be initialized.
        """
        # Fix: hold the lock so a concurrent measure/shutdown can never
        # observe a half-initialized pipeline (the original ran unlocked).
        with self._lock:
            if self._pipeline is not None:
                return
            try:
                pipeline, depth_intrinsics, _ = init_depth_pipeline()
            except Exception as exc:
                raise RuntimeError(f"Failed to init depth camera: {exc}") from exc
            self._pipeline = pipeline
            self._depth_intrinsics = depth_intrinsics
            self._temporal_filter = TemporalFilter(alpha=0.5)

    def shutdown(self) -> None:
        """Stop and release the camera pipeline (idempotent)."""
        # Fix: locked so the pipeline is never stopped while measure_height()
        # is mid-sampling (the original could race).
        with self._lock:
            if self._pipeline is None:
                return
            self._pipeline.stop()
            self._pipeline = None

    def measure_height(self) -> Optional[Dict[str, Any]]:
        """Collect up to ``config.sample_count`` samples and return the median.

        Returns:
            A dict with ``height_mm``, ``samples``, ``unit`` and
            ``sample_count``, or ``None`` when too few valid samples arrived
            before ``config.sample_timeout_sec`` elapsed.

        Raises:
            RuntimeError: if called before startup().
        """
        start_time = time.time()
        samples = []
        first_valid_sample = None
        first_color_frame = None

        with self._lock:
            if self._pipeline is None:
                # Fix: explicit error instead of an opaque AttributeError
                # deep inside _measure_once when the service never started.
                raise RuntimeError("Depth camera is not initialized; call startup() first")

            while len(samples) < self.config.sample_count and (time.time() - start_time) < self.config.sample_timeout_sec:
                sample = self._measure_once()
                if sample is None:
                    continue
                samples.append(sample["nearest_distance"])
                if first_valid_sample is None:
                    first_valid_sample = sample
                if first_color_frame is None and sample.get("color_frame") is not None:
                    first_color_frame = sample.get("color_frame")

            # If no color frame arrived during valid depth sampling, try a
            # few extra pulls so the saved snapshot includes a color image.
            if first_color_frame is None:
                for _ in range(5):
                    frames = self._pipeline.wait_for_frames(self.config.frame_timeout_ms)
                    if frames is None:
                        continue
                    color_frame = frames.get_color_frame()
                    if color_frame is not None:
                        first_color_frame = color_frame
                        break

            if first_valid_sample is not None:
                if first_color_frame is not None:
                    first_valid_sample["color_frame"] = first_color_frame
                self._save_current_sample_images(first_valid_sample)

        if len(samples) < self.config.sample_count:
            return None

        median_value = int(np.median(np.array(samples, dtype=np.int32)))
        return {
            "height_mm": median_value,
            "samples": samples,
            "unit": "mm",
            "sample_count": self.config.sample_count,
        }

    def _measure_once(self) -> Optional[Dict[str, Any]]:
        """One sampling attempt: grab a frame set and compute the nearest
        distance inside the ROI. Returns None when any stage fails."""
        frames = self._pipeline.wait_for_frames(self.config.frame_timeout_ms)
        if frames is None:
            return None

        color_frame = frames.get_color_frame()
        depth_frame = frames.get_depth_frame()
        depth_data = extract_depth_data(depth_frame, self.config.settings, self._temporal_filter)
        if depth_data is None:
            return None

        bounds = compute_roi_bounds(depth_data, self._depth_intrinsics, self.config.settings)
        if bounds is None:
            return None

        x_start, x_end, y_start, y_end, center_distance = bounds
        roi = depth_data[y_start:y_end, x_start:x_end]
        nearest_distance = nearest_distance_in_roi(roi, self.config.settings)
        if nearest_distance is None:
            return None

        return {
            "nearest_distance": nearest_distance,
            "color_frame": color_frame,
            "depth_data": depth_data,
            "bounds": bounds,
            "center_distance": center_distance,
        }

    def _save_current_sample_images(self, sample: Dict[str, Any]) -> None:
        """Write the color frame and an annotated depth map of *sample* to
        ./sample_images, then prune the directory to max_saved_images."""
        save_image_dir = os.path.join(os.getcwd(), "sample_images")
        os.makedirs(save_image_dir, exist_ok=True)

        # Fix: derive seconds and milliseconds from a single clock read so
        # the two timestamp parts cannot disagree across a second boundary
        # (the original called time.localtime() and time.time() separately).
        now = time.time()
        time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime(now))
        millis = int((now % 1) * 1000)
        timestamp = f"{time_str}_{millis:03d}"

        color_frame = sample.get("color_frame")
        if color_frame is not None:
            color_image = frame_to_bgr_image(color_frame)
            if color_image is not None:
                color_height, color_width = color_image.shape[:2]
                color_file = os.path.join(
                    save_image_dir,
                    f"color_{color_width}x{color_height}_{timestamp}.png",
                )
                cv2.imwrite(color_file, color_image)

        depth_data = sample["depth_data"]
        x_start, x_end, y_start, y_end, center_distance = sample["bounds"]
        nearest_distance = sample["nearest_distance"]
        roi = depth_data[y_start:y_end, x_start:x_end]
        nearest_point = find_nearest_point(roi, x_start, y_start, self.config.settings, nearest_distance)

        # False-color depth map annotated with ROI box and nearest point.
        depth_image = cv2.normalize(depth_data, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        depth_image = cv2.applyColorMap(depth_image, cv2.COLORMAP_JET)
        cv2.rectangle(
            depth_image,
            (x_start, y_start),
            (x_end - 1, y_end - 1),
            (0, 255, 0),
            2,
        )
        if nearest_point is not None:
            cv2.circle(depth_image, nearest_point, 4, (0, 0, 0), -1)
            cv2.circle(depth_image, nearest_point, 6, (0, 255, 255), 2)

        cv2.putText(
            depth_image,
            f"nearest: {nearest_distance} mm",
            (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.8,
            (255, 255, 255),
            2,
            cv2.LINE_AA,
        )
        cv2.putText(
            depth_image,
            f"center: {int(center_distance)} mm",
            (10, 60),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.8,
            (255, 255, 255),
            2,
            cv2.LINE_AA,
        )

        depth_h, depth_w = depth_image.shape[:2]
        depth_file = os.path.join(
            save_image_dir,
            f"depth_annotated_{depth_w}x{depth_h}_{timestamp}.png",
        )
        cv2.imwrite(depth_file, depth_image)
        self._prune_saved_images(save_image_dir, self.config.max_saved_images)

    @staticmethod
    def _prune_saved_images(save_dir: str, max_images: int) -> None:
        """Delete the oldest PNGs in *save_dir* until at most *max_images* remain."""
        png_files = [
            os.path.join(save_dir, name)
            for name in os.listdir(save_dir)
            if name.lower().endswith(".png")
        ]
        if len(png_files) <= max_images:
            return
        png_files.sort(key=os.path.getmtime)
        for file_path in png_files[: len(png_files) - max_images]:
            try:
                os.remove(file_path)
            except OSError:
                # Best-effort cleanup: a file removed concurrently is fine.
                pass

+ 97 - 0
request_logging.py

@@ -0,0 +1,97 @@
+import json
+import logging
+import os
+import time
+from datetime import datetime
+from logging.handlers import RotatingFileHandler
+
+from fastapi import FastAPI, Request
+from starlette.concurrency import iterate_in_threadpool
+
+
def setup_request_logging(app: FastAPI, max_response_len: int, max_bytes: int, backup_count: int) -> None:
    """Install request/response logging on *app*.

    Configures the "cargo_height.request" logger (rotating file in ./Log
    plus stderr) and registers an HTTP middleware that logs one line per
    request: timestamp, method, path, status, duration and a (truncated)
    copy of the response body.

    Args:
        app: the FastAPI application to instrument.
        max_response_len: max characters of the response body kept in the log line.
        max_bytes: rotate the log file after this many bytes.
        backup_count: number of rotated log files to keep.
    """
    logger = logging.getLogger("cargo_height.request")
    _setup_request_logger(logger, max_bytes=max_bytes, backup_count=backup_count)

    @app.middleware("http")
    async def request_log_middleware(request: Request, call_next):
        request_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
        start = time.perf_counter()
        try:
            response = await call_next(request)
        except Exception:
            elapsed_ms = (time.perf_counter() - start) * 1000
            # Unhandled errors are logged with a synthetic 500 status and
            # re-raised so the framework's error handling still runs.
            logger.exception(
                "request_time=%s method=%s path=%s status=%s duration_ms=%.2f response=%s",
                request_time,
                request.method,
                request.url.path,
                500,
                elapsed_ms,
                "internal_error",
            )
            raise

        # Drain the response body so it can be logged, collecting chunks in
        # a list (fix: the original used quadratic ``bytes += chunk``), then
        # restore an iterator so the client still receives the payload.
        chunks = []
        async for chunk in response.body_iterator:
            chunks.append(chunk)
        body = b"".join(chunks)
        response.body_iterator = iterate_in_threadpool(iter([body]))

        response_text = _parse_response_text(body, response.headers.get("content-type", ""))
        if len(response_text) > max_response_len:
            response_text = response_text[:max_response_len] + "...(truncated)"

        elapsed_ms = (time.perf_counter() - start) * 1000
        logger.info(
            "request_time=%s method=%s path=%s status=%s duration_ms=%.2f response=%s",
            request_time,
            request.method,
            request.url.path,
            response.status_code,
            elapsed_ms,
            response_text,
        )
        return response
+
+
def _setup_request_logger(logger: logging.Logger, max_bytes: int, backup_count: int) -> None:
    """Configure *logger* to write ./Log/request.log (rotating) and stderr.

    Idempotent: level and propagation are always (re)applied, but handlers
    are attached only the first time the logger is seen without any.
    """
    log_path = os.path.join(os.getcwd(), "Log", "request.log")
    os.makedirs(os.path.dirname(log_path), exist_ok=True)

    logger.setLevel(logging.INFO)
    logger.propagate = False
    if logger.handlers:
        return

    fmt = logging.Formatter(
        "%(asctime)s [%(levelname)s] %(name)s - %(message)s",
        "%Y-%m-%d %H:%M:%S",
    )

    for handler in (
        RotatingFileHandler(
            log_path,
            maxBytes=max_bytes,
            backupCount=backup_count,
            encoding="utf-8",
        ),
        logging.StreamHandler(),
    ):
        handler.setLevel(logging.INFO)
        handler.setFormatter(fmt)
        logger.addHandler(handler)
+
+
def _parse_response_text(body: bytes, content_type: str) -> str:
    """Render a response *body* as text for logging.

    JSON payloads are re-serialized compactly (non-ASCII preserved); any
    other payload — or unparsable JSON — is decoded as UTF-8 with
    replacement characters. An empty body yields an empty string.
    """
    if not body:
        return ""
    decoded = body.decode("utf-8", errors="replace")
    if "application/json" not in content_type:
        return decoded
    try:
        return json.dumps(json.loads(body), ensure_ascii=False)
    except Exception:
        return decoded