| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127 |
- #!/usr/bin/env python3
- from PIL import Image
- from ultralytics import YOLO
- from fastapi.responses import JSONResponse
- from fastapi import FastAPI, Request
- import base64
- import os
- import io
- from setproctitle import setproctitle
- setproctitle("yolo-router")
# Root directory holding one sub-directory per detection task. Each task keeps
# its trained weights at <BASE_DIR>/<task>/models/trained/<task>.pt.
BASE_DIR = "/tmp/yolov11"

# Task name -> absolute path of its weights file. The tuple order fixes the
# order in which models are loaded (and later listed by /health).
MODEL_PATHS = {
    task: os.path.join(BASE_DIR, f"{task}/models/trained/{task}.pt")
    for task in ("jianzhiliang", "gsgl", "suidao", "tezhongshebei", "jiayouzhan")
}

# Public `modeltype` value accepted by /predict -> task name (key of MODEL_PATHS).
MODELTYPE_MAP = dict(
    simple_supported_bridge="jianzhiliang",
    special_equipment="tezhongshebei",
    gas_station="jiayouzhan",
    operate_highway="gsgl",
    tunnel="suidao",
)
def _load_available_models(paths):
    """Load a YOLO model for every entry of *paths* whose weights file exists.

    Args:
        paths: Mapping of task name -> weights file path.

    Returns:
        Dict of task name -> loaded YOLO model, in *paths* order. Entries whose
        file is missing are skipped with a printed warning instead of raising.
    """
    loaded = {}
    for task, weights in paths.items():
        if not os.path.exists(weights):
            print(f"WARNING: model path not found: {weights}")
            continue
        print(f"Loading model {task} from {weights}")
        loaded[task] = YOLO(weights)
    return loaded


# Task name -> model, populated once at import time.
models = _load_available_models(MODEL_PATHS)

app = FastAPI(title="YOLO Router", version="1.0")
- def _parse_conf_threshold(raw_value):
- try:
- if raw_value is None or raw_value == "":
- return 0.5
- threshold = float(raw_value)
- if threshold < 0:
- return 0.0
- if threshold > 1:
- return 1.0
- return threshold
- except (TypeError, ValueError):
- return 0.5
# Per-model locks guarding inference. Ultralytics documents that sharing one
# model instance across threads is not thread-safe. Inference runs in a worker
# thread (so it does not block the event loop), so concurrent requests for the
# same model are serialized on these locks.
_MODEL_LOCKS = {}


def _run_model_locked(model, lock, image, conf_threshold):
    """Run one blocking inference call while holding *lock*."""
    with lock:
        return model(image, conf=conf_threshold, verbose=False)


def _collect_predictions(results):
    """Flatten the first Ultralytics result into parallel label/box/score lists.

    Boxes are the result's ``xyxy`` coordinates rounded to 2 decimals. Scores
    are confidences rounded to 4 decimals. All three lists are empty when there
    is no result or the result has no boxes.
    """
    labels, boxes_list, scores = [], [], []
    if not results:
        return labels, boxes_list, scores
    result = results[0]
    boxes = result.boxes
    if boxes is None:
        return labels, boxes_list, scores
    names = result.names
    # Convert each tensor to Python lists once instead of iterating per-box objects.
    for cls_id, conf, xyxy in zip(boxes.cls.tolist(), boxes.conf.tolist(), boxes.xyxy.tolist()):
        labels.append(names[int(cls_id)])
        boxes_list.append([round(coord, 2) for coord in xyxy])
        scores.append(round(conf, 4))
    return labels, boxes_list, scores


@app.post("/predict")
async def predict(request: Request):
    """Run the detection model selected by ``modeltype`` on a base64-encoded image.

    Expected JSON object body:
        modeltype: a key of ``MODELTYPE_MAP`` (e.g. ``"tunnel"``).
        image: base64-encoded image bytes that PIL can open.
        conf_threshold: optional confidence cut-off, normalized by
            ``_parse_conf_threshold``.

    Returns:
        On success, a dict with the echoed ``model_type``/``modeltype``, parallel
        ``labels``/``boxes``/``scores`` lists, ``count``, and ``detections``
        (one ``{"label", "confidence", "bbox"}`` dict per box).
        On failure, a ``JSONResponse`` carrying an ``error`` message, with status:
        400 for unparsable JSON, a non-object body, or an unknown/unloaded modeltype;
        422 for a missing ``modeltype``/``image`` or an undecodable image;
        500 when inference fails.
    """
    # Imported locally: only this handler needs them.
    import threading

    from fastapi.concurrency import run_in_threadpool

    try:
        data = await request.json()
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": f"Invalid JSON: {str(e)}"})
    # Valid JSON that is an array/string/number has no .get(); without this check
    # it escaped as an unhandled 500.
    if not isinstance(data, dict):
        return JSONResponse(status_code=400, content={"error": "Invalid JSON: request body must be an object"})
    modeltype = data.get("modeltype")
    image_b64 = data.get("image")
    conf_threshold = _parse_conf_threshold(data.get("conf_threshold"))
    if not modeltype or not image_b64:
        return JSONResponse(status_code=422, content={"error": "Missing modeltype or image"})
    # Only strings can be valid model types; an unhashable value (e.g. a list)
    # would otherwise make the dict lookup raise TypeError.
    model_name = MODELTYPE_MAP.get(modeltype) if isinstance(modeltype, str) else None
    if not model_name or model_name not in models:
        return JSONResponse(status_code=400, content={"error": f"Unknown modeltype '{modeltype}' or model not loaded"})
    try:
        img_bytes = base64.b64decode(image_b64)
        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except Exception as e:
        return JSONResponse(status_code=422, content={"error": f"Image decode error: {str(e)}"})

    # Lazily create the lock. This runs on the event-loop thread with no await in
    # between, so two requests cannot both create one for the same model.
    lock = _MODEL_LOCKS.get(model_name)
    if lock is None:
        lock = _MODEL_LOCKS[model_name] = threading.Lock()
    try:
        # Inference is blocking CPU/GPU work; run it in the threadpool so the
        # event loop (and /health) stays responsive.
        results = await run_in_threadpool(
            _run_model_locked, models[model_name], lock, image, conf_threshold
        )
        labels, boxes_list, scores = _collect_predictions(results)
    except Exception as e:
        print(f"Inference error: {e}")
        return JSONResponse(status_code=500, content={"error": f"Inference error: {str(e)}"})

    detections = [
        {"label": label, "confidence": score, "bbox": bbox}
        for label, bbox, score in zip(labels, boxes_list, scores)
    ]
    # Log the number of detected objects.
    print(f"[{modeltype}] Detected {len(detections)} objects")
    # Shaped to satisfy both the project backend's protocol and the `detections`
    # structure currently used for debugging.
    return {
        "model_type": modeltype,
        "modeltype": modeltype,
        "labels": labels,
        "boxes": boxes_list,
        "scores": scores,
        "count": len(detections),
        "detections": detections,
    }
@app.get("/health")
async def health():
    """Liveness probe: report "ok" plus the task names of the models loaded at startup."""
    loaded_names = [*models]
    return {"status": "ok", "loaded_models": loaded_names}
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port 18080.
    import uvicorn

    uvicorn.run(app, port=18080, host="0.0.0.0")
|