scene.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. import json
  2. import time
  3. from typing import Any, Optional
  4. from fastapi import APIRouter, Depends, Request
  5. from pydantic import BaseModel
  6. from sqlalchemy.orm import Session
  7. from database import get_db
  8. from models.scene import (
  9. FirstScene,
  10. RecognitionRecord,
  11. Scene,
  12. SceneTemplate,
  13. SecondScene,
  14. ThirdScene,
  15. )
# Router on which every endpoint defined in this module is registered.
router = APIRouter()
  17. def _get_user_code(user: Any) -> str:
  18. return (
  19. getattr(user, "userCode", None)
  20. or getattr(user, "user_code", None)
  21. or getattr(user, "account", "")
  22. )
  23. def _load_hazard_details(record: RecognitionRecord):
  24. if not record.hazard_details:
  25. return []
  26. try:
  27. data = json.loads(record.hazard_details)
  28. return data if isinstance(data, list) else []
  29. except Exception:
  30. return []
  31. def _split_labels(labels):
  32. if not labels:
  33. return []
  34. if isinstance(labels, list):
  35. return [str(item).strip() for item in labels if str(item).strip()]
  36. return [
  37. item.strip()
  38. for item in str(labels).replace(",", ",").split(",")
  39. if item.strip()
  40. ]
  41. def _unique_ordered(items):
  42. seen = set()
  43. ordered = []
  44. for item in items:
  45. if not item or item in seen:
  46. continue
  47. seen.add(item)
  48. ordered.append(item)
  49. return ordered
  50. def _build_record_view(record: RecognitionRecord):
  51. hazard_details = _load_hazard_details(record)
  52. derived_labels = _unique_ordered(
  53. [
  54. str(item.get("label") or "").strip()
  55. for item in hazard_details
  56. if str(item.get("label") or "").strip()
  57. ]
  58. )
  59. display_labels = _split_labels(record.labels) or derived_labels
  60. if record.description:
  61. third_scenes = [item for item in str(record.description).split(" ") if item]
  62. else:
  63. third_scenes = derived_labels
  64. detections = [
  65. {
  66. "label": item.get("label", ""),
  67. "box": item.get("bbox") or item.get("box") or [],
  68. "bbox": item.get("bbox") or item.get("box") or [],
  69. "confidence": item.get("confidence", 0),
  70. }
  71. for item in hazard_details
  72. ]
  73. return {
  74. "id": record.id,
  75. "title": record.title or "隐患提示记录",
  76. "description": record.description or " ".join(third_scenes),
  77. "original_image_url": record.original_image_url,
  78. "recognition_image_url": record.recognition_image_url,
  79. "labels": record.labels or ",".join(display_labels),
  80. "display_labels": display_labels,
  81. "third_scenes": third_scenes,
  82. "tag_type": record.tag_type or record.scene_type,
  83. "scene_type": record.scene_type,
  84. "effect_evaluation": record.effect_evaluation,
  85. "hazard_details": hazard_details,
  86. "detections": detections,
  87. }
  88. def _resolve_record_id(
  89. recognition_id: Optional[int] = None,
  90. recognition_record_id: Optional[int] = None,
  91. ):
  92. return recognition_id or recognition_record_id
  93. @router.get("/get_scene_list")
  94. async def get_scene_list(db: Session = Depends(get_db)):
  95. scenes = db.query(Scene).filter(Scene.is_deleted == 0).all()
  96. return {
  97. "statusCode": 200,
  98. "msg": "success",
  99. "data": [
  100. {
  101. "id": s.id,
  102. "scene_name": s.scene_name,
  103. "scene_en_name": s.scene_en_name,
  104. }
  105. for s in scenes
  106. ],
  107. }
  108. @router.get("/get_first_scene_list")
  109. async def get_first_scene_list(scene_id: int, db: Session = Depends(get_db)):
  110. scenes = (
  111. db.query(FirstScene)
  112. .filter(FirstScene.scene_id == scene_id, FirstScene.is_deleted == 0)
  113. .all()
  114. )
  115. return {
  116. "statusCode": 200,
  117. "msg": "success",
  118. "data": [{"id": s.id, "first_scene_name": s.first_scene_name} for s in scenes],
  119. }
  120. @router.get("/get_second_scene_list")
  121. async def get_second_scene_list(
  122. first_scene_id: int, db: Session = Depends(get_db)
  123. ):
  124. scenes = (
  125. db.query(SecondScene)
  126. .filter(
  127. SecondScene.first_scene_id == first_scene_id,
  128. SecondScene.is_deleted == 0,
  129. )
  130. .all()
  131. )
  132. return {
  133. "statusCode": 200,
  134. "msg": "success",
  135. "data": [{"id": s.id, "second_scene_name": s.second_scene_name} for s in scenes],
  136. }
  137. @router.get("/get_third_scene_list")
  138. async def get_third_scene_list(
  139. second_scene_id: int, db: Session = Depends(get_db)
  140. ):
  141. scenes = (
  142. db.query(ThirdScene)
  143. .filter(
  144. ThirdScene.second_scene_id == second_scene_id,
  145. ThirdScene.is_deleted == 0,
  146. )
  147. .all()
  148. )
  149. return {
  150. "statusCode": 200,
  151. "msg": "success",
  152. "data": [
  153. {
  154. "id": s.id,
  155. "third_scene_name": s.third_scene_name,
  156. "correct_example_image": s.correct_example_image,
  157. "wrong_example_image": s.wrong_example_image,
  158. }
  159. for s in scenes
  160. ],
  161. }
  162. @router.get("/get_third_scene_example_image")
  163. async def get_third_scene_example_image(
  164. third_scene_name: str, db: Session = Depends(get_db)
  165. ):
  166. if not third_scene_name:
  167. return {"statusCode": 400, "msg": "三级场景名称不能为空"}
  168. scene = (
  169. db.query(ThirdScene)
  170. .filter(
  171. ThirdScene.third_scene_name == third_scene_name,
  172. ThirdScene.is_deleted == 0,
  173. )
  174. .first()
  175. )
  176. if not scene:
  177. return {"statusCode": 404, "msg": "三级场景不存在"}
  178. return {
  179. "statusCode": 200,
  180. "msg": "success",
  181. "data": {
  182. "id": scene.id,
  183. "third_scene_name": scene.third_scene_name,
  184. "correct_example_image": scene.correct_example_image,
  185. "wrong_example_image": scene.wrong_example_image,
  186. },
  187. }
  188. @router.get("/get_history_recognition_record")
  189. async def get_history_recognition_record(
  190. request: Request, db: Session = Depends(get_db)
  191. ):
  192. user = request.state.user
  193. if not user:
  194. return {"statusCode": 401, "msg": "未授权"}
  195. user_code = _get_user_code(user)
  196. records = (
  197. db.query(RecognitionRecord)
  198. .filter(RecognitionRecord.user_id == user_code, RecognitionRecord.is_deleted == 0)
  199. .order_by(RecognitionRecord.updated_at.desc())
  200. .all()
  201. )
  202. total = (
  203. db.query(RecognitionRecord)
  204. .filter(RecognitionRecord.user_id == user_code, RecognitionRecord.is_deleted == 0)
  205. .count()
  206. )
  207. return {
  208. "statusCode": 200,
  209. "msg": "success",
  210. "data": [
  211. {
  212. **_build_record_view(record),
  213. "created_at": record.created_at,
  214. }
  215. for record in records
  216. ],
  217. "total": total,
  218. }
  219. @router.get("/get_recognition_record_detail")
  220. async def get_recognition_record_detail(
  221. recognition_id: Optional[int] = None,
  222. recognition_record_id: Optional[int] = None,
  223. db: Session = Depends(get_db),
  224. ):
  225. record_id = _resolve_record_id(recognition_id, recognition_record_id)
  226. if not record_id:
  227. return {"statusCode": 422, "msg": "recognition_id 不能为空"}
  228. record = (
  229. db.query(RecognitionRecord)
  230. .filter(RecognitionRecord.id == record_id, RecognitionRecord.is_deleted == 0)
  231. .first()
  232. )
  233. if not record:
  234. return {"statusCode": 404, "msg": "记录不存在"}
  235. record_view = _build_record_view(record)
  236. return {
  237. "statusCode": 200,
  238. "msg": "success",
  239. "data": {
  240. "id": record.id,
  241. "user_id": record.user_id,
  242. "title": record_view["title"],
  243. "description": record_view["description"],
  244. "original_image_url": record.original_image_url,
  245. "recognition_image_url": record.recognition_image_url,
  246. "labels": record_view["labels"],
  247. "display_labels": record_view["display_labels"],
  248. "third_scenes": record_view["third_scenes"],
  249. "tag_type": record_view["tag_type"],
  250. "scene_type": record.scene_type,
  251. "scene_match": record.scene_match,
  252. "tip_accuracy": record.tip_accuracy,
  253. "effect_evaluation": record.effect_evaluation,
  254. "user_remark": record.user_remark,
  255. "hazard_details": record_view["hazard_details"],
  256. "detections": record_view["detections"],
  257. "created_at": record.created_at,
  258. "updated_at": record.updated_at,
  259. },
  260. }
# Request body of POST /delete_recognition_record. The record id may be sent
# under either field name; both are optional at the schema level.
class DeleteRecognitionRequest(BaseModel):
    # Preferred field; used whenever it is truthy.
    recognition_id: Optional[int] = None
    # Alternative field name, used when recognition_id is missing or falsy.
    recognition_record_id: Optional[int] = None
  264. @router.post("/delete_recognition_record")
  265. async def delete_recognition_record(
  266. data: DeleteRecognitionRequest,
  267. request: Request,
  268. db: Session = Depends(get_db),
  269. ):
  270. user = request.state.user
  271. if not user:
  272. return {"statusCode": 401, "msg": "未授权"}
  273. record_id = _resolve_record_id(data.recognition_id, data.recognition_record_id)
  274. if not record_id:
  275. return {"statusCode": 422, "msg": "recognition_id 不能为空"}
  276. (
  277. db.query(RecognitionRecord)
  278. .filter(
  279. RecognitionRecord.id == record_id,
  280. RecognitionRecord.user_id == _get_user_code(user),
  281. )
  282. .update({"is_deleted": 1, "deleted_at": int(time.time())})
  283. )
  284. db.commit()
  285. return {"statusCode": 200, "msg": "删除成功"}
# Request body of POST /submit_evaluation. Only `id` is required; each optional
# field left as None keeps the record's current value.
class EvaluationRequest(BaseModel):
    # Id of the recognition record being evaluated.
    id: int
    # Evaluation values, stored on the record's columns of the same name.
    scene_match: Optional[int] = None
    tip_accuracy: Optional[int] = None
    effect_evaluation: Optional[int] = None
    # Free-text remark stored on the record.
    user_remark: Optional[str] = None
  292. @router.post("/submit_evaluation")
  293. async def submit_evaluation(data: EvaluationRequest, db: Session = Depends(get_db)):
  294. record = (
  295. db.query(RecognitionRecord)
  296. .filter(RecognitionRecord.id == data.id, RecognitionRecord.is_deleted == 0)
  297. .first()
  298. )
  299. if not record:
  300. return {"statusCode": 404, "msg": "记录不存在"}
  301. if data.scene_match is not None:
  302. record.scene_match = data.scene_match
  303. if data.tip_accuracy is not None:
  304. record.tip_accuracy = data.tip_accuracy
  305. if data.effect_evaluation is not None:
  306. record.effect_evaluation = data.effect_evaluation
  307. if data.user_remark is not None:
  308. record.user_remark = data.user_remark
  309. record.updated_at = int(time.time())
  310. db.commit()
  311. return {"statusCode": 200, "msg": "success"}
  312. @router.get("/get_latest_recognition_record")
  313. async def get_latest_recognition_record(
  314. request: Request, db: Session = Depends(get_db)
  315. ):
  316. user = request.state.user
  317. if not user:
  318. return {"statusCode": 401, "msg": "未授权"}
  319. record = (
  320. db.query(RecognitionRecord)
  321. .filter(
  322. RecognitionRecord.user_id == _get_user_code(user),
  323. RecognitionRecord.is_deleted == 0,
  324. )
  325. .order_by(RecognitionRecord.created_at.desc())
  326. .first()
  327. )
  328. if not record:
  329. return {
  330. "statusCode": 200,
  331. "msg": "success",
  332. "data": {"effect_evaluation": 1},
  333. }
  334. return {
  335. "statusCode": 200,
  336. "msg": "success",
  337. "data": {
  338. "id": record.id,
  339. "title": record.title,
  340. "original_image_url": record.original_image_url,
  341. "recognition_image_url": record.recognition_image_url,
  342. "labels": record.labels,
  343. "created_at": record.created_at,
  344. "effect_evaluation": record.effect_evaluation,
  345. },
  346. }
# Request body of POST /scene_template; each field is copied to the
# SceneTemplate column of the same name.
class SceneTemplateCreate(BaseModel):
    scene_name: str
    scene_type: str
    # Optional description; defaults to an empty string.
    scene_desc: str = ""
    # NOTE(review): under Pydantic v2 a field starting with "model_" triggers a
    # protected-namespace warning - confirm which Pydantic version is in use.
    model_name: str
  352. @router.post("/scene_template")
  353. async def create_scene_template(
  354. data: SceneTemplateCreate, db: Session = Depends(get_db)
  355. ):
  356. template = SceneTemplate(
  357. scene_name=data.scene_name,
  358. scene_type=data.scene_type,
  359. scene_desc=data.scene_desc,
  360. model_name=data.model_name,
  361. created_at=int(time.time()),
  362. updated_at=int(time.time()),
  363. is_deleted=0,
  364. )
  365. db.add(template)
  366. db.commit()
  367. db.refresh(template)
  368. return {
  369. "statusCode": 200,
  370. "msg": "创建成功",
  371. "data": {"id": template.id},
  372. }
  373. @router.get("/scene_templates")
  374. async def get_scene_templates(
  375. page: int = 1,
  376. page_size: int = 20,
  377. db: Session = Depends(get_db),
  378. ):
  379. if page_size > 100:
  380. page_size = 100
  381. offset = (page - 1) * page_size
  382. total = db.query(SceneTemplate).filter(SceneTemplate.is_deleted == 0).count()
  383. templates = (
  384. db.query(SceneTemplate)
  385. .filter(SceneTemplate.is_deleted == 0)
  386. .order_by(SceneTemplate.created_at.desc())
  387. .offset(offset)
  388. .limit(page_size)
  389. .all()
  390. )
  391. return {
  392. "statusCode": 200,
  393. "msg": "success",
  394. "data": {
  395. "total": total,
  396. "items": [
  397. {
  398. "id": template.id,
  399. "scene_name": template.scene_name,
  400. "scene_type": template.scene_type,
  401. "scene_desc": template.scene_desc,
  402. "model_name": template.model_name,
  403. "created_at": template.created_at,
  404. }
  405. for template in templates
  406. ],
  407. },
  408. }
  409. @router.get("/recognition_records")
  410. async def get_recognition_records(
  411. request: Request,
  412. scene_type: str = "",
  413. page: int = 1,
  414. page_size: int = 20,
  415. db: Session = Depends(get_db),
  416. ):
  417. user = request.state.user
  418. if not user:
  419. return {"statusCode": 401, "msg": "未授权"}
  420. if page_size > 100:
  421. page_size = 100
  422. query = db.query(RecognitionRecord).filter(
  423. RecognitionRecord.user_id == _get_user_code(user),
  424. RecognitionRecord.is_deleted == 0,
  425. )
  426. if scene_type:
  427. query = query.filter(RecognitionRecord.scene_type == scene_type)
  428. total = query.count()
  429. offset = (page - 1) * page_size
  430. records = (
  431. query.order_by(RecognitionRecord.created_at.desc())
  432. .offset(offset)
  433. .limit(page_size)
  434. .all()
  435. )
  436. return {
  437. "statusCode": 200,
  438. "msg": "success",
  439. "data": {
  440. "total": total,
  441. "items": [
  442. {
  443. "id": record.id,
  444. "scene_type": record.scene_type,
  445. "original_image_url": record.original_image_url,
  446. "result_image_url": record.recognition_image_url,
  447. "hazard_count": record.hazard_count,
  448. "current_step": record.current_step,
  449. "created_at": record.created_at,
  450. }
  451. for record in records
  452. ],
  453. },
  454. }