llm_utils.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. # -*- coding: utf-8 -*-
  2. """Small LLM output helpers."""
  3. import json
  4. import re
  5. from typing import Any, Dict, Optional
  6. _FENCED_JSON_RE = re.compile(r"```(?:json)?\s*([\s\S]*?)\s*```", re.IGNORECASE)
  7. # Regex fallback: extract "answer" value from a JSON-like structure.
  8. # Handles both "answer": "..." (double-quoted) and multi-line values.
  9. _ANSWER_FIELD_RE = re.compile(
  10. r'"answer"\s*:\s*"((?:[^"\\]|\\.)*)"',
  11. re.DOTALL,
  12. )
  13. def extract_json_object(text: str) -> Dict[str, Any]:
  14. """Extract a JSON object from a model response."""
  15. if not text:
  16. return {}
  17. stripped = text.strip()
  18. fenced_match = _FENCED_JSON_RE.search(stripped)
  19. if fenced_match:
  20. stripped = fenced_match.group(1).strip()
  21. try:
  22. value = json.loads(stripped)
  23. return value if isinstance(value, dict) else {}
  24. except json.JSONDecodeError:
  25. pass
  26. start = stripped.find("{")
  27. end = stripped.rfind("}")
  28. if start >= 0 and end > start:
  29. fragment = stripped[start:end + 1]
  30. try:
  31. value = json.loads(fragment)
  32. return value if isinstance(value, dict) else {}
  33. except json.JSONDecodeError:
  34. # Retry with control characters escaped (common when model
  35. # emits literal newlines/tabs inside string values).
  36. repaired = _repair_control_chars(fragment)
  37. if repaired != fragment:
  38. try:
  39. value = json.loads(repaired)
  40. return value if isinstance(value, dict) else {}
  41. except json.JSONDecodeError:
  42. pass
  43. return {}
  44. def extract_answer_field(text: str) -> Optional[str]:
  45. """Best-effort extraction of the "answer" field from a raw LLM response.
  46. Used as a fallback when ``extract_json_object`` fails to parse the full
  47. JSON (e.g. due to unescaped control characters in streamed output).
  48. """
  49. if not text:
  50. return None
  51. match = _ANSWER_FIELD_RE.search(text)
  52. if not match:
  53. return None
  54. raw_value = match.group(1)
  55. # Unescape standard JSON escape sequences.
  56. try:
  57. return json.loads(f'"{raw_value}"')
  58. except json.JSONDecodeError:
  59. return raw_value
  60. def _repair_control_chars(s: str) -> str:
  61. """Replace literal control chars inside JSON string values.
  62. Models sometimes emit raw newlines / tabs inside string literals,
  63. which ``json.loads`` rejects. This replaces them with proper escapes
  64. while leaving the surrounding JSON structure intact.
  65. """
  66. # Only replace control characters that appear between quotes.
  67. # A simple approach: replace all bare \n/\r/\t with escaped versions,
  68. # but skip already-escaped sequences (preceded by backslash).
  69. result = []
  70. i = 0
  71. in_string = False
  72. while i < len(s):
  73. c = s[i]
  74. if c == '"' and (i == 0 or s[i - 1] != "\\"):
  75. in_string = not in_string
  76. result.append(c)
  77. elif in_string and c == "\n":
  78. result.append("\\n")
  79. elif in_string and c == "\r":
  80. result.append("\\r")
  81. elif in_string and c == "\t":
  82. result.append("\\t")
  83. else:
  84. result.append(c)
  85. i += 1
  86. return "".join(result)
  87. def compact_json(value: Any) -> str:
  88. return json.dumps(value, ensure_ascii=False, indent=2)