#!/usr/bin/env python3
"""YOLO router: one HTTP service fronting several Ultralytics YOLO models.

At import time every model listed in ``MODEL_PATHS`` whose weight file exists
is loaded. Endpoints:

* ``POST /predict`` -- JSON body ``{"modeltype": str, "image": base64 str,
  "conf_threshold": optional number}``; runs the model mapped from
  ``modeltype`` and returns its detections.
* ``GET /health`` -- status plus the names of the models that were loaded.
"""
import base64
import io
import math
import os
import threading

from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from setproctitle import setproctitle
from ultralytics import YOLO

setproctitle("yolo-router")

BASE_DIR = "/tmp/yolov11"

# Internal model key -> path of its trained weight file.
MODEL_PATHS = {
    "jianzhiliang": os.path.join(BASE_DIR, "jianzhiliang/models/trained/jianzhiliang.pt"),
    "gsgl": os.path.join(BASE_DIR, "gsgl/models/trained/gsgl.pt"),
    "suidao": os.path.join(BASE_DIR, "suidao/models/trained/suidao.pt"),
    "tezhongshebei": os.path.join(BASE_DIR, "tezhongshebei/models/trained/tezhongshebei.pt"),
    "jiayouzhan": os.path.join(BASE_DIR, "jiayouzhan/models/trained/jiayouzhan.pt"),
}

# Public ``modeltype`` value accepted by /predict -> internal model key.
MODELTYPE_MAP = {
    "simple_supported_bridge": "jianzhiliang",
    "special_equipment": "tezhongshebei",
    "gas_station": "jiayouzhan",
    "operate_highway": "gsgl",
    "tunnel": "suidao",
}

# Confidence threshold used when the request omits or garbles one.
DEFAULT_CONF_THRESHOLD = 0.5

# Loaded models keyed like MODEL_PATHS. A missing weight file is skipped with a
# warning so the service still starts with the remaining models.
models = {}
for name, path in MODEL_PATHS.items():
    if os.path.exists(path):
        print(f"Loading model {name} from {path}")
        models[name] = YOLO(path)
    else:
        print(f"WARNING: model path not found: {path}")

# Inference runs in worker threads so it does not block the event loop, but
# Ultralytics advises against running one model instance from several threads
# at once. This lock keeps predictions one-at-a-time, as they were when the
# model was called directly on the (single-threaded) event loop.
_INFERENCE_LOCK = threading.Lock()

app = FastAPI(title="YOLO Router", version="1.0")


def _parse_conf_threshold(raw_value):
    """Convert a client-supplied confidence threshold to a float in [0, 1].

    ``None``, ``""``, unparseable values and NaN fall back to
    ``DEFAULT_CONF_THRESHOLD``; finite out-of-range values (and +/-inf) are
    clamped to the nearest bound.
    """
    if raw_value is None or raw_value == "":
        return DEFAULT_CONF_THRESHOLD
    try:
        threshold = float(raw_value)
    except (TypeError, ValueError):
        return DEFAULT_CONF_THRESHOLD
    if math.isnan(threshold):
        # NaN compares False against both bounds and would slip past the clamp.
        return DEFAULT_CONF_THRESHOLD
    return min(max(threshold, 0.0), 1.0)


def _decode_image(image_b64):
    """Decode a base64 image payload into an RGB ``PIL.Image``.

    A ``data:<mime>;base64,`` prefix (as produced by browsers / canvas) is
    stripped first. Without this, the non-validating ``b64decode`` would drop
    only the prefix's non-alphabet characters and decode garbage.

    Raises on malformed input (base64 or image errors); the caller maps any
    exception to HTTP 422.
    """
    if not isinstance(image_b64, str):
        raise TypeError("image must be a base64-encoded string")
    if image_b64.startswith("data:"):
        image_b64 = image_b64.partition(",")[2]
    img_bytes = base64.b64decode(image_b64)
    # convert() loads the pixels and returns a new image, so the source image
    # can be closed on exit.
    with Image.open(io.BytesIO(img_bytes)) as img:
        return img.convert("RGB")


def _run_inference(model, image, conf_threshold):
    """Run ``model`` on ``image`` and flatten the first result.

    Returns three parallel lists: class labels, ``xyxy`` boxes rounded to
    2 decimals, and confidences rounded to 4 decimals. Blocking; meant to be
    called through ``run_in_threadpool``.
    """
    with _INFERENCE_LOCK:
        results = model(image, conf=conf_threshold, verbose=False)

    labels, boxes_out, scores = [], [], []
    # A single image is passed in, so only the first result is used; its
    # ``boxes`` attribute can be None and is guarded.
    if results:
        result = results[0]
        if result.boxes is not None:
            for box in result.boxes:
                labels.append(result.names[int(box.cls[0])])
                boxes_out.append([round(v, 2) for v in box.xyxy[0].tolist()])
                scores.append(round(float(box.conf[0]), 4))
    return labels, boxes_out, scores


@app.post("/predict")
async def predict(request: Request):
    """Run object detection on one base64-encoded image.

    Status codes: 400 for unparseable / non-object JSON or an unknown or
    unloaded ``modeltype``; 422 for a missing field or an undecodable image;
    500 if inference itself fails.
    """
    try:
        data = await request.json()
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": f"Invalid JSON: {str(e)}"})
    if not isinstance(data, dict):
        # A valid JSON array/scalar would otherwise crash on data.get() -> 500.
        return JSONResponse(status_code=400, content={"error": "Request body must be a JSON object"})

    modeltype = data.get("modeltype")
    image_b64 = data.get("image")
    conf_threshold = _parse_conf_threshold(data.get("conf_threshold"))

    if not modeltype or not image_b64:
        return JSONResponse(status_code=422, content={"error": "Missing modeltype or image"})

    # Non-string modeltypes (e.g. a list) may be unhashable and would raise in .get().
    model_name = MODELTYPE_MAP.get(modeltype) if isinstance(modeltype, str) else None
    if not model_name or model_name not in models:
        return JSONResponse(
            status_code=400,
            content={"error": f"Unknown modeltype '{modeltype}' or model not loaded"},
        )

    try:
        image = await run_in_threadpool(_decode_image, image_b64)
    except Exception as e:
        return JSONResponse(status_code=422, content={"error": f"Image decode error: {str(e)}"})

    try:
        labels, boxes_list, scores = await run_in_threadpool(
            _run_inference, models[model_name], image, conf_threshold
        )
    except Exception as e:
        print(f"Inference error: {e}")
        return JSONResponse(status_code=500, content={"error": f"Inference error: {str(e)}"})

    detections = [
        {"label": label, "confidence": score, "bbox": bbox}
        for label, bbox, score in zip(labels, boxes_list, scores)
    ]

    # Log the number of detected objects.
    print(f"[{modeltype}] Detected {len(detections)} objects")

    # Serve both the project backend's response protocol and the `detections`
    # structure currently used for debugging.
    return {
        "model_type": modeltype,
        "modeltype": modeltype,
        "labels": labels,
        "boxes": boxes_list,
        "scores": scores,
        "count": len(detections),
        "detections": detections,
    }


@app.get("/health")
async def health():
    """Report liveness and which models were loaded at start-up."""
    return {"status": "ok", "loaded_models": list(models.keys())}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=18080)