scrape.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. from __future__ import annotations
  2. import asyncio
  3. from datetime import datetime
  4. from typing import List, Optional
  5. import json
  6. from fastapi import APIRouter, HTTPException
  7. from pydantic import BaseModel
  8. from app.db import get_pool
  9. from app.services.scraper import ScraperService
  10. router = APIRouter(tags=["scrape"])
  11. _scraper = ScraperService()
  12. class ScrapeRequest(BaseModel):
  13. urls: List[str]
  14. # 可选:url -> api_key 映射,爬取时传递给爬虫
  15. api_keys: Optional[dict] = None
  16. class ScrapeJobOut(BaseModel):
  17. job_id: str
  18. status: str
  19. error: Optional[str] = None
  20. created_at: datetime
  21. class ScrapeResultOut(BaseModel):
  22. url: str
  23. model_name: str
  24. prices: dict
  25. model_info: Optional[dict] = None
  26. rate_limits: Optional[dict] = None
  27. tool_prices: Optional[list] = None
  28. scraped_at: datetime
  29. class ScrapeJobDetailOut(BaseModel):
  30. job_id: str
  31. status: str
  32. error: Optional[str] = None
  33. created_at: datetime
  34. results: Optional[List[ScrapeResultOut]] = None
  35. @router.post("/scrape", response_model=ScrapeJobOut, status_code=202)
  36. async def create_scrape_job(body: ScrapeRequest) -> ScrapeJobOut:
  37. pool = get_pool()
  38. # 如果没有传 api_keys,从数据库中自动查询(通过 api_key_id JOIN api_keys 表)
  39. api_keys = body.api_keys or {}
  40. if not api_keys:
  41. async with pool.acquire() as conn:
  42. rows = await conn.fetch(
  43. """
  44. SELECT m.url, k.key_value
  45. FROM models m
  46. JOIN api_keys k ON k.id = m.api_key_id
  47. WHERE m.url = ANY($1::text[]) AND m.api_key_id IS NOT NULL
  48. """,
  49. body.urls,
  50. )
  51. api_keys = {row["url"]: row["key_value"] for row in rows}
  52. async with pool.acquire() as conn:
  53. row = await conn.fetchrow(
  54. """
  55. INSERT INTO scrape_jobs (urls, status)
  56. VALUES ($1, 'pending')
  57. RETURNING id, status, error, created_at
  58. """,
  59. body.urls,
  60. )
  61. job_id = str(row["id"])
  62. asyncio.create_task(_scraper.run_job(job_id, body.urls, pool, api_keys=api_keys))
  63. return ScrapeJobOut(
  64. job_id=job_id,
  65. status=row["status"],
  66. error=row["error"],
  67. created_at=row["created_at"],
  68. )
  69. @router.get("/scrape", response_model=List[ScrapeJobOut])
  70. async def list_scrape_jobs() -> List[ScrapeJobOut]:
  71. pool = get_pool()
  72. async with pool.acquire() as conn:
  73. rows = await conn.fetch(
  74. "SELECT id, status, error, created_at FROM scrape_jobs ORDER BY created_at DESC"
  75. )
  76. return [
  77. ScrapeJobOut(job_id=str(r["id"]), status=r["status"], error=r["error"], created_at=r["created_at"])
  78. for r in rows
  79. ]
  80. @router.get("/scrape/{job_id}", response_model=ScrapeJobDetailOut)
  81. async def get_scrape_job(job_id: str) -> ScrapeJobDetailOut:
  82. pool = get_pool()
  83. async with pool.acquire() as conn:
  84. row = await conn.fetchrow(
  85. "SELECT id, status, error, created_at FROM scrape_jobs WHERE id = $1",
  86. job_id,
  87. )
  88. if row is None:
  89. raise HTTPException(status_code=404, detail="Scrape job not found")
  90. results: Optional[List[ScrapeResultOut]] = None
  91. if row["status"] == "done":
  92. result_rows = await conn.fetch(
  93. "SELECT url, model_name, prices, model_info, rate_limits, tool_prices, scraped_at FROM scrape_results WHERE job_id = $1 ORDER BY scraped_at ASC",
  94. job_id,
  95. )
  96. def _j(v):
  97. if v is None: return None
  98. return v if isinstance(v, (dict, list)) else json.loads(v)
  99. results = [
  100. ScrapeResultOut(
  101. url=r["url"],
  102. model_name=r["model_name"],
  103. prices=_j(r["prices"]) or {},
  104. model_info=_j(r["model_info"]),
  105. rate_limits=_j(r["rate_limits"]),
  106. tool_prices=_j(r["tool_prices"]),
  107. scraped_at=r["scraped_at"],
  108. )
  109. for r in result_rows
  110. ]
  111. return ScrapeJobDetailOut(
  112. job_id=str(row["id"]),
  113. status=row["status"],
  114. error=row["error"],
  115. created_at=row["created_at"],
  116. results=results,
  117. )