Explorar o código

添加模型图标

lxylxy123321 hai 1 semana
pai
achega
1d831684a8

+ 4 - 2
backend/app/routers/public.py

@@ -23,6 +23,7 @@ class PublicPriceOut(BaseModel):
     model_info: Optional[dict] = None
     rate_limits: Optional[dict] = None
     tool_prices: Optional[list] = None
+    icon: Optional[str] = None
     scraped_at: datetime
 
 
@@ -135,11 +136,11 @@ async def get_public_prices(
     # 从 price_snapshot 读取数据
     if url is None:
         rows = await pool.fetch(
-            "SELECT url, model_name, prices, model_info, rate_limits, tool_prices, updated_at FROM price_snapshot ORDER BY url"
+            "SELECT url, model_name, prices, model_info, rate_limits, tool_prices, icon, updated_at FROM price_snapshot ORDER BY url"
         )
     else:
         rows = await pool.fetch(
-            "SELECT url, model_name, prices, model_info, rate_limits, tool_prices, updated_at FROM price_snapshot WHERE url = $1",
+            "SELECT url, model_name, prices, model_info, rate_limits, tool_prices, icon, updated_at FROM price_snapshot WHERE url = $1",
             url,
         )
         if not rows:
@@ -167,6 +168,7 @@ async def get_public_prices(
         model_info=_j(r["model_info"]),
         rate_limits=_j(r["rate_limits"]),
         tool_prices=_j(r["tool_prices"]),
+        icon=r["icon"],
         scraped_at=r["updated_at"],
     ) for r in rows]
 

+ 9 - 2
backend/app/services/geo.py

@@ -144,7 +144,7 @@ class GeoResolver:
         if not lib:
             return None
         try:
-            # 返回格式: 国家|区域|省|市|ISP  例: 中国|0|四川|成都|电信
+            # 返回格式: 国家|区域|省|市|ISP  例: 中国|0|四川|成都|电信
             result = lib.get_region(ip)
             parts = result.split("|")
             if len(parts) < 5:
@@ -153,8 +153,15 @@ class GeoResolver:
             if country_raw not in ("中国", "中国大陆"):
                 return None
             coords = _lookup_coords(province, city)
+            # 坐标查不到时回退到省会
+            if not coords:
+                prov_clean = province.replace("省","").replace("自治区","").replace("壮族","").replace("回族","").replace("维吾尔","")
+                capital = _PROVINCE_CAPITAL.get(prov_clean)
+                if capital:
+                    coords = _CITY_COORDS.get(capital)
             lat, lon = (coords[0], coords[1]) if coords else (None, None)
-            city_clean = city.replace("市", "").replace("区", "") if city and city != "0" else province.replace("省", "").replace("自治区", "").replace("壮族", "").replace("回族", "").replace("维吾尔", "")
+            city_clean = city.replace("市", "").replace("区", "") if city and city != "0" else \
+                province.replace("省","").replace("自治区","").replace("壮族","").replace("回族","").replace("维吾尔","")
             city_display = _CITY_EN.get(city_clean, city_clean)
             return GeoInfo("China", city_display, lat, lon, isp if isp != "0" else None)
         except Exception:

+ 11 - 8
backend/app/services/scraper.py

@@ -60,7 +60,7 @@ class ScraperService:
                         headless=headless,
                         timeout=20000,
                         executable_path=exec_path,
-                        modules=["info", "rate", "tool", "price"],
+                        modules=["info", "rate", "tool", "price", "icon"],
                     ),
                 )
 
@@ -68,6 +68,7 @@ class ScraperService:
                 model_info  = result.get("info") or {}
                 rate_limits = result.get("rate_limits") or {}
                 tool_prices = result.get("tool_call_prices") or []
+                icon        = result.get("icon")  # SVG string or None
 
                 # model_name: 直接用 URL 中提取的 model_id,保持和用户输入一致
                 model_name = (
@@ -79,18 +80,18 @@ class ScraperService:
                     await conn.execute(
                         """
                         INSERT INTO scrape_results
-                            (job_id, url, model_name, prices, model_info, rate_limits, tool_prices, raw_data)
-                        VALUES ($1, $2, $3, $4::jsonb, $5::jsonb, $6::jsonb, $7::jsonb, $8::jsonb)
+                            (job_id, url, model_name, prices, model_info, rate_limits, tool_prices, raw_data, icon)
+                        VALUES ($1, $2, $3, $4::jsonb, $5::jsonb, $6::jsonb, $7::jsonb, $8::jsonb, $9)
                         """,
                         job_id, url, model_name,
                         json.dumps(prices), json.dumps(model_info),
                         json.dumps(rate_limits), json.dumps(tool_prices),
-                        json.dumps(result),
+                        json.dumps(result), icon,
                     )
 
                     # 对比旧快照,有变化才 upsert
                     existing = await conn.fetchrow(
-                        "SELECT prices, model_info, rate_limits, tool_prices FROM price_snapshot WHERE url = $1",
+                        "SELECT prices, model_info, rate_limits, tool_prices, icon FROM price_snapshot WHERE url = $1",
                         url,
                     )
                     data_changed = (
@@ -99,6 +100,7 @@ class ScraperService:
                         or _norm(existing["model_info"])  != _norm(model_info)
                         or _norm(existing["rate_limits"]) != _norm(rate_limits)
                         or _norm(existing["tool_prices"]) != _norm(tool_prices)
+                        or (existing["icon"] or "") != (icon or "")
                     )
 
                     if data_changed:
@@ -106,19 +108,20 @@ class ScraperService:
                         await conn.execute(
                             """
                             INSERT INTO price_snapshot
-                                (url, model_name, prices, model_info, rate_limits, tool_prices, updated_at)
-                            VALUES ($1, $2, $3::jsonb, $4::jsonb, $5::jsonb, $6::jsonb, NOW())
+                                (url, model_name, prices, model_info, rate_limits, tool_prices, icon, updated_at)
+                            VALUES ($1, $2, $3::jsonb, $4::jsonb, $5::jsonb, $6::jsonb, $7, NOW())
                             ON CONFLICT (url) DO UPDATE SET
                                 model_name  = EXCLUDED.model_name,
                                 prices      = EXCLUDED.prices,
                                 model_info  = EXCLUDED.model_info,
                                 rate_limits = EXCLUDED.rate_limits,
                                 tool_prices = EXCLUDED.tool_prices,
+                                icon        = EXCLUDED.icon,
                                 updated_at  = NOW()
                             """,
                             url, model_name,
                             json.dumps(prices), json.dumps(model_info),
-                            json.dumps(rate_limits), json.dumps(tool_prices),
+                            json.dumps(rate_limits), json.dumps(tool_prices), icon,
                         )
 
             # 删除 snapshot 里不在本次爬取列表中的行(模型被移除的情况)

+ 7 - 1
backend/crawl/main.py

@@ -45,6 +45,7 @@ from scrape_tool_prices import (
 from scrape_aliyun_models import (
     scrape_model_price,
 )
+from scrape_model_icon import _extract_icon_from_page
 
 
 def _navigate(page, url: str, timeout: int) -> bool:
@@ -94,7 +95,7 @@ def scrape_all(
     默认全部运行。
     """
     if modules is None:
-        modules = ["info", "rate", "tool", "price"]
+        modules = ["info", "rate", "tool", "price", "icon"]
 
     target = _extract_model_id_from_url(url)
     result: Dict = {"url": url, "model_id": target, "error": None}
@@ -173,6 +174,11 @@ def scrape_all(
                     tool_text = _get_tool_price_section_text(html)
                     result["tool_call_prices"] = parse_tool_prices_from_text(tool_text) if tool_text else []
 
+                # ── icon 模块 ──
+                if "icon" in shared_modules:
+                    icon = _extract_icon_from_page(page)
+                    result["icon"] = icon.get("data") if icon.get("type") != "none" else None
+
                 browser.close()
 
     # ── price 模块(原始脚本,独立浏览器) ──────────────────────────────────────

+ 243 - 0
backend/crawl/scrape_model_icon.py

@@ -0,0 +1,243 @@
+#!/usr/bin/env python3
+"""
+Aliyun Model Icon Scraper
+用 Playwright 渲染模型详情页,从 DOM 中提取模型图标(SVG 或 img)。
+
+用法:
+  python scrape_model_icon.py --url "https://bailian.console.aliyun.com/.../qwen3-max"
+  python scrape_model_icon.py --url "..." --save-svg icons/qwen3-max.svg
+  python scrape_model_icon.py --url "..." --screenshot icons/qwen3-max.png
+"""
+
+import argparse
+import json
+import os
+import re
+import time
+from typing import Optional, Dict
+
+from playwright.sync_api import sync_playwright, TimeoutError as PlaywrightTimeoutError
+
+
+# 按优先级依次尝试的选择器
+# 卡片大图标区域(模型详情页左上角)
+_ICON_SELECTORS = [
+    # 模型详情页 header 里的图标容器
+    '[class*="modelIcon"] svg',
+    '[class*="modelIcon"] img',
+    '[class*="model-icon"] svg',
+    '[class*="model-icon"] img',
+    # 面包屑里的小图标(备用)
+    '[class*="currentModelIcon"]',
+    # 通用:页面 header 区域第一个 svg
+    '.pageHeader svg',
+    '[class*="pageHeader"] svg',
+    # 最后兜底:页面内第一个尺寸合理的 svg
+]
+
+
def _extract_icon_from_page(page) -> Dict:
    """Extract the model icon from an already-rendered Playwright page.

    Runs injected JS that walks a prioritized selector list (model-icon
    containers first, then breadcrumb / page-header fallbacks, finally any
    visible SVG whose rendered size is between 24 and 200 px).

    Returns a dict:
        {"type": "svg" | "img" | "none",
         "data": str | None,   # SVG outerHTML (xmlns ensured) or img src URL
         "selector": str | None}  # the selector that matched, for debugging

    NOTE: the selector array inside the JS duplicates the module-level
    _ICON_SELECTORS list because it has to live in browser context —
    keep the two in sync when editing.
    """
    result = page.evaluate(
        """
        () => {
            const selectors = [
                '[class*="modelIcon"] svg',
                '[class*="modelIcon"] img',
                '[class*="model-icon"] svg',
                '[class*="model-icon"] img',
                '[class*="currentModelIcon"]',
                '.pageHeader svg',
                '[class*="pageHeader"] svg',
            ];

            const isVisible = (el) => {
                if (!el) return false;
                const r = el.getBoundingClientRect();
                const s = window.getComputedStyle(el);
                return r.width > 0 && r.height > 0
                    && s.display !== 'none'
                    && s.visibility !== 'hidden'
                    && s.opacity !== '0';
            };

            for (const sel of selectors) {
                const el = document.querySelector(sel);
                if (!el || !isVisible(el)) continue;

                if (el.tagName.toLowerCase() === 'svg') {
                    // 克隆并清理,确保 SVG 有 xmlns
                    const clone = el.cloneNode(true);
                    if (!clone.getAttribute('xmlns')) {
                        clone.setAttribute('xmlns', 'http://www.w3.org/2000/svg');
                    }
                    return { type: 'svg', data: clone.outerHTML, selector: sel };
                }

                if (el.tagName.toLowerCase() === 'img') {
                    return { type: 'img', data: el.src || el.getAttribute('src'), selector: sel };
                }

                // 容器里找 svg/img
                const svg = el.querySelector('svg');
                if (svg && isVisible(svg)) {
                    const clone = svg.cloneNode(true);
                    if (!clone.getAttribute('xmlns')) {
                        clone.setAttribute('xmlns', 'http://www.w3.org/2000/svg');
                    }
                    return { type: 'svg', data: clone.outerHTML, selector: sel + ' > svg' };
                }
                const img = el.querySelector('img');
                if (img && isVisible(img)) {
                    return { type: 'img', data: img.src || img.getAttribute('src'), selector: sel + ' > img' };
                }
            }

            // 兜底:找页面内所有 svg,取尺寸在 24~200px 之间的第一个
            const allSvgs = Array.from(document.querySelectorAll('svg'));
            for (const svg of allSvgs) {
                if (!isVisible(svg)) continue;
                const r = svg.getBoundingClientRect();
                if (r.width >= 24 && r.width <= 200 && r.height >= 24 && r.height <= 200) {
                    const clone = svg.cloneNode(true);
                    if (!clone.getAttribute('xmlns')) {
                        clone.setAttribute('xmlns', 'http://www.w3.org/2000/svg');
                    }
                    return { type: 'svg', data: clone.outerHTML, selector: 'svg[fallback]' };
                }
            }

            return { type: 'none', data: null, selector: null };
        }
        """
    )
    # page.evaluate yields None when the script evaluates to undefined;
    # normalize that to the "no icon found" shape.
    return result or {"type": "none", "data": None, "selector": None}
+
+
def scrape_model_icon(
    url: str,
    headless: bool = True,
    timeout: int = 20000,
    executable_path: Optional[str] = None,
    save_svg: Optional[str] = None,
    screenshot: Optional[str] = None,
) -> Dict:
    """Scrape the model icon from a model detail page.

    Parameters:
        url: model detail page URL.
        headless: run the browser headless.
        timeout: navigation timeout in milliseconds.
        executable_path: optional browser executable path.
        save_svg: if set and an SVG was found, write it to this path.
        screenshot: if set, save a PNG screenshot of the icon (or viewport).

    Returns:
        {
          "url": str,
          "icon_type": "svg" | "img" | "none",
          "icon_data": str,   # SVG outerHTML or img src URL
          "selector": str,    # selector that matched
          "error": str | None
        }
        Extra keys may appear: "icon_url", "fetch_error", "saved_svg",
        "saved_screenshot", "screenshot_error".
    """
    result = {"url": url, "icon_type": "none", "icon_data": None, "selector": None, "error": None}

    with sync_playwright() as p:
        launch_kwargs = {"headless": headless}
        if executable_path:
            launch_kwargs["executable_path"] = executable_path

        browser = p.chromium.launch(**launch_kwargs)
        # try/finally guarantees the browser process is closed even if any
        # step after launch raises (the original leaked it on such paths).
        try:
            page = browser.new_context().new_page()

            try:
                page.goto(url, wait_until="networkidle", timeout=timeout)
            except PlaywrightTimeoutError:
                # networkidle often times out on busy pages; retry with "load".
                try:
                    page.goto(url, wait_until="load", timeout=timeout)
                except Exception as e:
                    result["error"] = f"导航失败: {e}"
                    return result

            # 等待页面主体内容出现 (wait for the page body content to appear)
            for sel in ["text=模型介绍", "text=模型价格", '[class*="modelIcon"]', '[class*="pageHeader"]']:
                try:
                    page.wait_for_selector(sel, timeout=6000)
                    break
                except PlaywrightTimeoutError:
                    pass
            # Small grace period for late-rendering icons.
            time.sleep(1.0)

            icon = _extract_icon_from_page(page)
            result["icon_type"] = icon["type"]
            result["icon_data"] = icon["data"]
            result["selector"] = icon["selector"]

            # If the icon is an <img> whose URL points at an SVG file,
            # download the content and promote it to inline SVG data.
            # Check the URL *path* so "?v=2"-style query strings still match.
            if icon["type"] == "img" and icon["data"]:
                import urllib.request
                from urllib.parse import urlsplit
                if urlsplit(icon["data"]).path.endswith(".svg"):
                    try:
                        with urllib.request.urlopen(icon["data"], timeout=10) as resp:
                            svg_content = resp.read().decode("utf-8")
                        result["icon_type"] = "svg"
                        result["icon_data"] = svg_content
                        result["icon_url"] = icon["data"]
                        icon = {**icon, "type": "svg", "data": svg_content}
                    except Exception as e:
                        # Best-effort: keep the img URL, record why fetch failed.
                        result["fetch_error"] = str(e)

            # 保存 SVG 文件 (optionally persist the SVG to disk)
            if save_svg and icon["type"] == "svg" and icon["data"]:
                os.makedirs(os.path.dirname(save_svg) or ".", exist_ok=True)
                with open(save_svg, "w", encoding="utf-8") as f:
                    f.write(icon["data"])
                result["saved_svg"] = save_svg

            # 截图保存 (optionally save a screenshot)
            if screenshot:
                os.makedirs(os.path.dirname(screenshot) or ".", exist_ok=True)
                try:
                    # Prefer shooting the icon element itself; fall back to viewport.
                    el = page.locator('[class*="modelIcon"]').first
                    if el.count() > 0:
                        el.screenshot(path=screenshot)
                    else:
                        page.screenshot(path=screenshot, full_page=False)
                    result["saved_screenshot"] = screenshot
                except Exception as e:
                    result["screenshot_error"] = str(e)
        finally:
            browser.close()

    return result
+
+
def main():
    """CLI entry point: scrape one model icon and print the result as JSON."""
    parser = argparse.ArgumentParser(description="爬取阿里云百炼模型图标(SVG/img)")
    parser.add_argument("--url", required=True, help="模型详情页 URL")
    parser.add_argument("--headful", action="store_true", help="有头模式(方便调试)")
    parser.add_argument("--timeout", type=int, default=20000)
    parser.add_argument("--browser-path", help="浏览器可执行文件路径")
    parser.add_argument("--save-svg", help="将 SVG 保存到指定路径,如 icons/qwen3-max.svg")
    parser.add_argument("--screenshot", help="将图标截图保存为 PNG,如 icons/qwen3-max.png")
    opts = parser.parse_args()

    # Headless unless --headful was given; PLAYWRIGHT_HEADLESS=false in the
    # environment also forces a visible browser.
    run_headless = not opts.headful
    if os.environ.get("PLAYWRIGHT_HEADLESS", "").lower() == "false":
        run_headless = False

    outcome = scrape_model_icon(
        url=opts.url,
        headless=run_headless,
        timeout=opts.timeout,
        executable_path=opts.browser_path or os.environ.get("PLAYWRIGHT_EXECUTABLE"),
        save_svg=opts.save_svg,
        screenshot=opts.screenshot,
    )

    print(json.dumps(outcome, ensure_ascii=False, indent=2))

+ 8 - 0
backend/migrations/011_icon.sql

@@ -0,0 +1,8 @@
-- Migration 011: add icon field to price_snapshot and scrape_results
-- The icon column holds the scraped model icon as free-form text
-- (an SVG string per the scraper, or NULL when no icon was found).
SET search_path TO crawl;

-- IF NOT EXISTS keeps the migration idempotent / safe to re-run.
ALTER TABLE price_snapshot
    ADD COLUMN IF NOT EXISTS icon TEXT;

ALTER TABLE scrape_results
    ADD COLUMN IF NOT EXISTS icon TEXT;