app.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. #!/usr/bin/env python3
  2. from PIL import Image
  3. from ultralytics import YOLO
  4. from fastapi.responses import JSONResponse
  5. from fastapi import FastAPI, Request
  6. import base64
  7. import os
  8. import io
  9. from setproctitle import setproctitle
  10. setproctitle("yolo-router")
  11. BASE_DIR = "/tmp/yolov11"
  12. MODEL_PATHS = {
  13. "jianzhiliang": os.path.join(BASE_DIR, "jianzhiliang/models/trained/jianzhiliang.pt"),
  14. "gsgl": os.path.join(BASE_DIR, "gsgl/models/trained/gsgl.pt"),
  15. "suidao": os.path.join(BASE_DIR, "suidao/models/trained/suidao.pt"),
  16. "tezhongshebei": os.path.join(BASE_DIR, "tezhongshebei/models/trained/tezhongshebei.pt"),
  17. "jiayouzhan": os.path.join(BASE_DIR, "jiayouzhan/models/trained/jiayouzhan.pt"),
  18. }
  19. MODELTYPE_MAP = {
  20. "simple_supported_bridge": "jianzhiliang",
  21. "special_equipment": "tezhongshebei",
  22. "gas_station": "jiayouzhan",
  23. "operate_highway": "gsgl",
  24. "tunnel": "suidao",
  25. }
  26. models = {}
  27. for name, path in MODEL_PATHS.items():
  28. if os.path.exists(path):
  29. print(f"Loading model {name} from {path}")
  30. models[name] = YOLO(path)
  31. else:
  32. print(f"WARNING: model path not found: {path}")
  33. app = FastAPI(title="YOLO Router", version="1.0")
  34. def _parse_conf_threshold(raw_value):
  35. try:
  36. if raw_value is None or raw_value == "":
  37. return 0.5
  38. threshold = float(raw_value)
  39. if threshold < 0:
  40. return 0.0
  41. if threshold > 1:
  42. return 1.0
  43. return threshold
  44. except (TypeError, ValueError):
  45. return 0.5
  46. @app.post("/predict")
  47. async def predict(request: Request):
  48. try:
  49. data = await request.json()
  50. except Exception as e:
  51. return JSONResponse(status_code=400, content={"error": f"Invalid JSON: {str(e)}"})
  52. modeltype = data.get("modeltype")
  53. image_b64 = data.get("image")
  54. conf_threshold = _parse_conf_threshold(data.get("conf_threshold"))
  55. if not modeltype or not image_b64:
  56. return JSONResponse(status_code=422, content={"error": "Missing modeltype or image"})
  57. model_name = MODELTYPE_MAP.get(modeltype)
  58. if not model_name or model_name not in models:
  59. return JSONResponse(status_code=400, content={"error": f"Unknown modeltype '{modeltype}' or model not loaded"})
  60. try:
  61. img_bytes = base64.b64decode(image_b64)
  62. image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
  63. except Exception as e:
  64. return JSONResponse(status_code=422, content={"error": f"Image decode error: {str(e)}"})
  65. try:
  66. results = models[model_name](image, conf=conf_threshold, verbose=False)
  67. detections = []
  68. labels = []
  69. boxes_list = []
  70. scores = []
  71. if len(results) > 0:
  72. result = results[0]
  73. boxes = result.boxes
  74. if boxes is not None:
  75. for box in boxes:
  76. cls_id = int(box.cls[0])
  77. conf = float(box.conf[0])
  78. xyxy = box.xyxy[0].tolist()
  79. label = result.names[cls_id]
  80. detections.append({
  81. "label": label,
  82. "confidence": round(conf, 4),
  83. "bbox": [round(x, 2) for x in xyxy]
  84. })
  85. labels.append(label)
  86. boxes_list.append([round(x, 2) for x in xyxy])
  87. scores.append(round(conf, 4))
  88. # ★ 打印目标数量到日志
  89. print(f"[{modeltype}] Detected {len(detections)} objects")
  90. # 同时兼容项目后端协议和当前调试使用的 detections 结构。
  91. return {
  92. "model_type": modeltype,
  93. "modeltype": modeltype,
  94. "labels": labels,
  95. "boxes": boxes_list,
  96. "scores": scores,
  97. "count": len(detections),
  98. "detections": detections
  99. }
  100. except Exception as e:
  101. print(f"Inference error: {e}")
  102. return JSONResponse(status_code=500, content={"error": f"Inference error: {str(e)}"})
  103. @app.get("/health")
  104. async def health():
  105. return {"status": "ok", "loaded_models": list(models.keys())}
  106. if __name__ == "__main__":
  107. import uvicorn
  108. uvicorn.run(app, host="0.0.0.0", port=18080)