hazard.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. """
  2. Hazard detection routes.
  3. """
  4. from typing import Any, Dict, List, Optional
  5. import io
  6. import json
  7. import time
  8. import httpx
  9. from fastapi import APIRouter, Depends, Request
  10. from pydantic import BaseModel
  11. from sqlalchemy.orm import Session
  12. from PIL import Image, ImageDraw, ImageFont
  13. from database import get_db
  14. from models.scene import RecognitionRecord
  15. from services.oss_service import oss_service
  16. from services.yolo_service import yolo_service
  17. from utils.crypto import decrypt_url
  18. from utils.logger import logger
  19. router = APIRouter()
  20. class HazardRequest(BaseModel):
  21. """Compatible request model for old and new frontend payloads."""
  22. image_url: Optional[str] = None
  23. image: Optional[str] = None
  24. scene_type: str = ""
  25. scene_name: str = ""
  26. user_name: str = ""
  27. username: str = ""
  28. user_account: str = ""
  29. account: str = ""
  30. date: str = ""
  31. class SaveStepRequest(BaseModel):
  32. """Save current step for a recognition record."""
  33. record_id: int
  34. current_step: int
  35. SCENE_KEY_ALIASES = {
  36. "tunnel": "tunnel",
  37. "隧道": "tunnel",
  38. "隧道施工": "tunnel",
  39. "隧道工程": "tunnel",
  40. "simple_supported_bridge": "simple_supported_bridge",
  41. "bridge": "simple_supported_bridge",
  42. "桥梁": "simple_supported_bridge",
  43. "桥梁施工": "simple_supported_bridge",
  44. "桥梁工程": "simple_supported_bridge",
  45. "gas_station": "gas_station",
  46. "加油站": "gas_station",
  47. "special_equipment": "special_equipment",
  48. "特种设备": "special_equipment",
  49. "operate_highway": "operate_highway",
  50. "运营高速公路": "operate_highway",
  51. }
  52. SCENE_DISPLAY_NAMES = {
  53. "tunnel": "隧道工程",
  54. "simple_supported_bridge": "桥梁工程",
  55. "gas_station": "加油站",
  56. "special_equipment": "特种设备",
  57. "operate_highway": "运营高速公路",
  58. }
  59. def _get_user_code(user: Any) -> str:
  60. return (
  61. getattr(user, "userCode", None)
  62. or getattr(user, "user_code", None)
  63. or getattr(user, "account", "")
  64. )
  65. def _resolve_scene_key(scene_value: str) -> str:
  66. if not scene_value:
  67. return ""
  68. return SCENE_KEY_ALIASES.get(scene_value.strip(), scene_value.strip())
  69. def _unique_ordered(items: List[str]) -> List[str]:
  70. seen = set()
  71. ordered = []
  72. for item in items:
  73. if not item or item in seen:
  74. continue
  75. seen.add(item)
  76. ordered.append(item)
  77. return ordered
  78. def _build_frontend_result(hazards: List[Dict[str, Any]]) -> Dict[str, Any]:
  79. raw_labels: List[str] = []
  80. element_hazards: Dict[str, List[str]] = {}
  81. detections: List[Dict[str, Any]] = []
  82. for hazard in hazards:
  83. label = str(hazard.get("label") or "").strip()
  84. if not label:
  85. continue
  86. raw_labels.append(label)
  87. element_hazards.setdefault(label, [])
  88. if label not in element_hazards[label]:
  89. element_hazards[label].append(label)
  90. box = hazard.get("bbox") or hazard.get("box") or []
  91. detections.append(
  92. {
  93. "label": label,
  94. "box": box,
  95. "bbox": box,
  96. "confidence": hazard.get("confidence", 0),
  97. }
  98. )
  99. display_labels = _unique_ordered(raw_labels)
  100. return {
  101. "display_labels": display_labels,
  102. "labels": display_labels,
  103. "third_scenes": display_labels,
  104. "element_hazards": element_hazards,
  105. "detections": detections,
  106. }
  107. @router.post("/hazard")
  108. async def hazard(
  109. request: Request,
  110. data: HazardRequest,
  111. db: Session = Depends(get_db),
  112. ):
  113. """Run hazard detection and return a frontend-compatible payload."""
  114. user = request.state.user
  115. if not user:
  116. return {"statusCode": 401, "msg": "未授权"}
  117. try:
  118. source_image_url = data.image_url or data.image
  119. if not source_image_url:
  120. return {"statusCode": 422, "msg": "image_url 不能为空"}
  121. scene_key = _resolve_scene_key(data.scene_type or data.scene_name)
  122. user_code = _get_user_code(user)
  123. user_name = (
  124. data.user_name
  125. or data.username
  126. or getattr(user, "name", None)
  127. or getattr(user, "username", None)
  128. or getattr(user, "account", "")
  129. )
  130. user_account = (
  131. data.user_account
  132. or data.account
  133. or getattr(user, "account", "")
  134. )
  135. try:
  136. real_image_url = decrypt_url(source_image_url)
  137. except Exception:
  138. real_image_url = source_image_url
  139. async with httpx.AsyncClient(timeout=30.0) as client:
  140. img_response = await client.get(real_image_url)
  141. img_response.raise_for_status()
  142. image_bytes = img_response.content
  143. yolo_result = await yolo_service.detect_hazards(real_image_url, scene_key)
  144. hazards = yolo_result.get("hazards", []) or []
  145. hazard_count = len(hazards)
  146. frontend_result = _build_frontend_result(hazards)
  147. current_ts = int(time.time())
  148. result_image_bytes = await _draw_boxes_and_watermark(
  149. image_bytes,
  150. hazards,
  151. user_name=user_name,
  152. user_account=user_account,
  153. )
  154. result_filename = f"hazard_detection/{user_code}/{current_ts}.jpg"
  155. result_url = await oss_service.upload_bytes(result_image_bytes, result_filename)
  156. scene_display_name = SCENE_DISPLAY_NAMES.get(scene_key, scene_key or "隐患提示")
  157. record = RecognitionRecord(
  158. user_id=user_code,
  159. scene_type=scene_key,
  160. original_image_url=source_image_url,
  161. recognition_image_url=result_url,
  162. hazard_count=hazard_count,
  163. hazard_details=json.dumps(hazards, ensure_ascii=False),
  164. current_step=1,
  165. title=f"{scene_display_name}隐患提示",
  166. description=" ".join(frontend_result["third_scenes"]),
  167. labels=",".join(frontend_result["display_labels"]),
  168. tag_type=scene_key,
  169. created_at=current_ts,
  170. updated_at=current_ts,
  171. is_deleted=0,
  172. )
  173. db.add(record)
  174. db.commit()
  175. db.refresh(record)
  176. return {
  177. "statusCode": 200,
  178. "msg": "识别成功",
  179. "data": {
  180. "record_id": record.id,
  181. "hazard_count": hazard_count,
  182. "hazards": hazards,
  183. "scene_name": scene_key,
  184. "annotated_image": result_url,
  185. "display_labels": frontend_result["display_labels"],
  186. "labels": frontend_result["labels"],
  187. "third_scenes": frontend_result["third_scenes"],
  188. "element_hazards": frontend_result["element_hazards"],
  189. "detections": frontend_result["detections"],
  190. "result_image_url": result_url,
  191. "original_image_url": source_image_url,
  192. },
  193. }
  194. except httpx.HTTPError as e:
  195. logger.error(f"[hazard] 图片下载失败: {e}")
  196. db.rollback()
  197. return {"statusCode": 500, "msg": f"图片下载失败: {str(e)}"}
  198. except Exception as e:
  199. logger.error(f"[hazard] 处理异常: {e}")
  200. db.rollback()
  201. return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}
  202. @router.post("/save_step")
  203. async def save_step(
  204. request: Request,
  205. data: SaveStepRequest,
  206. db: Session = Depends(get_db),
  207. ):
  208. """Update RecognitionRecord.current_step."""
  209. user = request.state.user
  210. if not user:
  211. return {"statusCode": 401, "msg": "未授权"}
  212. try:
  213. affected = (
  214. db.query(RecognitionRecord)
  215. .filter(
  216. RecognitionRecord.id == data.record_id,
  217. RecognitionRecord.user_id == _get_user_code(user),
  218. )
  219. .update(
  220. {
  221. "current_step": data.current_step,
  222. "updated_at": int(time.time()),
  223. }
  224. )
  225. )
  226. if affected == 0:
  227. return {"statusCode": 404, "msg": "记录不存在"}
  228. db.commit()
  229. return {
  230. "statusCode": 200,
  231. "msg": "保存成功",
  232. "data": {
  233. "record_id": data.record_id,
  234. "current_step": data.current_step,
  235. },
  236. }
  237. except Exception as e:
  238. logger.error(f"[save_step] 异常: {e}")
  239. db.rollback()
  240. return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}
  241. async def _draw_boxes_and_watermark(
  242. image_bytes: bytes,
  243. hazards: List[Dict[str, Any]],
  244. user_name: str,
  245. user_account: str,
  246. ) -> bytes:
  247. """Draw detection boxes and a tiled watermark on the image."""
  248. try:
  249. image = Image.open(io.BytesIO(image_bytes)).convert("RGBA")
  250. width, height = image.size
  251. overlay = Image.new("RGBA", (width, height), (255, 255, 255, 0))
  252. draw = ImageDraw.Draw(overlay)
  253. try:
  254. font = ImageFont.truetype(
  255. "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20
  256. )
  257. font_small = ImageFont.truetype(
  258. "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14
  259. )
  260. except Exception:
  261. try:
  262. font = ImageFont.truetype("C:/Windows/Fonts/msyh.ttc", 20)
  263. font_small = ImageFont.truetype("C:/Windows/Fonts/msyh.ttc", 14)
  264. except Exception:
  265. font = ImageFont.load_default()
  266. font_small = ImageFont.load_default()
  267. for hazard in hazards:
  268. bbox = hazard.get("bbox", []) or hazard.get("box", [])
  269. label = hazard.get("label", "")
  270. confidence = hazard.get("confidence", 0)
  271. if len(bbox) == 4:
  272. x1, y1, x2, y2 = bbox
  273. draw.rectangle([x1, y1, x2, y2], outline=(255, 0, 0, 255), width=3)
  274. text = f"{label} {confidence:.2f}"
  275. draw.text(
  276. (x1, max(0, y1 - 25)),
  277. text,
  278. fill=(255, 0, 0, 255),
  279. font=font,
  280. )
  281. current_date = time.strftime("%Y/%m/%d")
  282. watermarks = [user_name or "", user_account or "", current_date]
  283. watermarks = [text for text in watermarks if text]
  284. if not watermarks:
  285. watermarks = [current_date]
  286. text_height_estimate = 50
  287. text_width_estimate = 150
  288. angle = 45
  289. watermark_layer = Image.new(
  290. "RGBA", (width * 2, height * 2), (255, 255, 255, 0)
  291. )
  292. watermark_draw = ImageDraw.Draw(watermark_layer)
  293. for y in range(-height, height * 2, text_height_estimate):
  294. for x in range(-width, width * 2, text_width_estimate):
  295. row_index = int(y / text_height_estimate) % len(watermarks)
  296. watermark_draw.text(
  297. (x, y),
  298. watermarks[row_index],
  299. fill=(128, 128, 128, 60),
  300. font=font_small,
  301. )
  302. watermark_layer = watermark_layer.rotate(
  303. angle, expand=False, fillcolor=(255, 255, 255, 0)
  304. )
  305. crop_x = (watermark_layer.width - width) // 2
  306. crop_y = (watermark_layer.height - height) // 2
  307. watermark_layer = watermark_layer.crop(
  308. (crop_x, crop_y, crop_x + width, crop_y + height)
  309. )
  310. image = Image.alpha_composite(image, watermark_layer)
  311. image = Image.alpha_composite(image, overlay)
  312. final_image = image.convert("RGB")
  313. output = io.BytesIO()
  314. final_image.save(output, format="JPEG", quality=95)
  315. return output.getvalue()
  316. except Exception as e:
  317. logger.error(f"[_draw_boxes_and_watermark] 图片处理失败: {e}")
  318. return image_bytes