|
|
@@ -0,0 +1,127 @@
|
|
|
+#!/usr/bin/env python3
|
|
|
+from PIL import Image
|
|
|
+from ultralytics import YOLO
|
|
|
+from fastapi.responses import JSONResponse
|
|
|
+from fastapi import FastAPI, Request
|
|
|
+import base64
|
|
|
+import os
|
|
|
+import io
|
|
|
+from setproctitle import setproctitle
|
|
|
+setproctitle("yolo-router")
|
|
|
+
|
|
|
+BASE_DIR = "/tmp/yolov11"
|
|
|
+
|
|
|
+MODEL_PATHS = {
|
|
|
+ "jianzhiliang": os.path.join(BASE_DIR, "jianzhiliang/models/trained/jianzhiliang.pt"),
|
|
|
+ "gsgl": os.path.join(BASE_DIR, "gsgl/models/trained/gsgl.pt"),
|
|
|
+ "suidao": os.path.join(BASE_DIR, "suidao/models/trained/suidao.pt"),
|
|
|
+ "tezhongshebei": os.path.join(BASE_DIR, "tezhongshebei/models/trained/tezhongshebei.pt"),
|
|
|
+ "jiayouzhan": os.path.join(BASE_DIR, "jiayouzhan/models/trained/jiayouzhan.pt"),
|
|
|
+}
|
|
|
+
|
|
|
+MODELTYPE_MAP = {
|
|
|
+ "simple_supported_bridge": "jianzhiliang",
|
|
|
+ "special_equipment": "tezhongshebei",
|
|
|
+ "gas_station": "jiayouzhan",
|
|
|
+ "operate_highway": "gsgl",
|
|
|
+ "tunnel": "suidao",
|
|
|
+}
|
|
|
+
|
|
|
+models = {}
|
|
|
+for name, path in MODEL_PATHS.items():
|
|
|
+ if os.path.exists(path):
|
|
|
+ print(f"Loading model {name} from {path}")
|
|
|
+ models[name] = YOLO(path)
|
|
|
+ else:
|
|
|
+ print(f"WARNING: model path not found: {path}")
|
|
|
+
|
|
|
+app = FastAPI(title="YOLO Router", version="1.0")
|
|
|
+
|
|
|
+
|
|
|
+def _parse_conf_threshold(raw_value):
|
|
|
+ try:
|
|
|
+ if raw_value is None or raw_value == "":
|
|
|
+ return 0.5
|
|
|
+ threshold = float(raw_value)
|
|
|
+ if threshold < 0:
|
|
|
+ return 0.0
|
|
|
+ if threshold > 1:
|
|
|
+ return 1.0
|
|
|
+ return threshold
|
|
|
+ except (TypeError, ValueError):
|
|
|
+ return 0.5
|
|
|
+
|
|
|
+
|
|
|
+@app.post("/predict")
|
|
|
+async def predict(request: Request):
|
|
|
+ try:
|
|
|
+ data = await request.json()
|
|
|
+ except Exception as e:
|
|
|
+ return JSONResponse(status_code=400, content={"error": f"Invalid JSON: {str(e)}"})
|
|
|
+
|
|
|
+ modeltype = data.get("modeltype")
|
|
|
+ image_b64 = data.get("image")
|
|
|
+ conf_threshold = _parse_conf_threshold(data.get("conf_threshold"))
|
|
|
+
|
|
|
+ if not modeltype or not image_b64:
|
|
|
+ return JSONResponse(status_code=422, content={"error": "Missing modeltype or image"})
|
|
|
+
|
|
|
+ model_name = MODELTYPE_MAP.get(modeltype)
|
|
|
+ if not model_name or model_name not in models:
|
|
|
+ return JSONResponse(status_code=400, content={"error": f"Unknown modeltype '{modeltype}' or model not loaded"})
|
|
|
+
|
|
|
+ try:
|
|
|
+ img_bytes = base64.b64decode(image_b64)
|
|
|
+ image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
|
|
+ except Exception as e:
|
|
|
+ return JSONResponse(status_code=422, content={"error": f"Image decode error: {str(e)}"})
|
|
|
+
|
|
|
+ try:
|
|
|
+ results = models[model_name](image, conf=conf_threshold, verbose=False)
|
|
|
+ detections = []
|
|
|
+ labels = []
|
|
|
+ boxes_list = []
|
|
|
+ scores = []
|
|
|
+ if len(results) > 0:
|
|
|
+ result = results[0]
|
|
|
+ boxes = result.boxes
|
|
|
+ if boxes is not None:
|
|
|
+ for box in boxes:
|
|
|
+ cls_id = int(box.cls[0])
|
|
|
+ conf = float(box.conf[0])
|
|
|
+ xyxy = box.xyxy[0].tolist()
|
|
|
+ label = result.names[cls_id]
|
|
|
+ detections.append({
|
|
|
+ "label": label,
|
|
|
+ "confidence": round(conf, 4),
|
|
|
+ "bbox": [round(x, 2) for x in xyxy]
|
|
|
+ })
|
|
|
+ labels.append(label)
|
|
|
+ boxes_list.append([round(x, 2) for x in xyxy])
|
|
|
+ scores.append(round(conf, 4))
|
|
|
+
|
|
|
+ # ★ 打印目标数量到日志
|
|
|
+ print(f"[{modeltype}] Detected {len(detections)} objects")
|
|
|
+
|
|
|
+ # 同时兼容项目后端协议和当前调试使用的 detections 结构。
|
|
|
+ return {
|
|
|
+ "model_type": modeltype,
|
|
|
+ "modeltype": modeltype,
|
|
|
+ "labels": labels,
|
|
|
+ "boxes": boxes_list,
|
|
|
+ "scores": scores,
|
|
|
+ "count": len(detections),
|
|
|
+ "detections": detections
|
|
|
+ }
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Inference error: {e}")
|
|
|
+ return JSONResponse(status_code=500, content={"error": f"Inference error: {str(e)}"})
|
|
|
+
|
|
|
+
|
|
|
+@app.get("/health")
|
|
|
+async def health():
|
|
|
+ return {"status": "ok", "loaded_models": list(models.keys())}
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ import uvicorn
|
|
|
+ uvicorn.run(app, host="0.0.0.0", port=18080)
|