| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290 |
- """
- 提示词管理 API 端点
- 提供提示词的查询、版本管理、对比、激活和回滚功能。
- 基于 core/debug/prompt_manager.py 的 PromptManager 实现。
- """
- import logging
- from fastapi import APIRouter, Query, HTTPException
- from pydantic import ValidationError
- from core.debug.prompt_manager import PromptManager, CHAINS, PROMPT_FILE_MAP, PROMPT_CHAIN_MAP
- logger = logging.getLogger(__name__)
- # 全局 PromptManager 实例
- _prompt_manager = None
- def _get_manager() -> PromptManager:
- """获取或创建 PromptManager 单例"""
- global _prompt_manager
- if _prompt_manager is None:
- _prompt_manager = PromptManager()
- return _prompt_manager
- def register_routes(router: APIRouter):
- """在指定 router 上注册提示词管理路由"""
- @router.get(
- '/api/prompts',
- summary='获取提示词列表',
- description='获取所有提示词及其版本信息,支持链路筛选和名称搜索。',
- )
- async def list_prompts(
- chain: str = Query(default=None, description='链路筛选'),
- search: str = Query(default=None, description='名称搜索'),
- page: int = Query(default=1, ge=1, description='页码'),
- page_size: int = Query(default=50, ge=1, le=200, description='每页条数'),
- ):
- """
- 获取所有提示词及其版本信息。
- """
- try:
- manager = _get_manager()
- items = manager.get_all_prompts(chain_filter=chain, search=search)
- # 分页
- total = len(items)
- start = (page - 1) * page_size
- paged_items = items[start:start + page_size]
- return {
- 'status': 'ok',
- 'total': total,
- 'page': page,
- 'page_size': page_size,
- 'items': paged_items,
- 'chains': CHAINS,
- }
- except Exception as e:
- logger.error('获取提示词列表失败: %s', e)
- raise HTTPException(status_code=500, detail=str(e))
- @router.get(
- '/api/prompts/{name}',
- summary='获取提示词详情',
- description='获取指定提示词指定版本的完整详情,包括 system_prompt 和 user_prompt。',
- )
- async def get_prompt_detail(
- name: str,
- version: str = Query(default=None, description='版本号,不指定则返回当前激活版本'),
- ):
- """
- 获取指定提示词的完整详情,包括系统提示词和用户提示词模板。
- """
- try:
- manager = _get_manager()
- detail = manager.get_prompt_detail(name, version=version)
- if detail is None:
- raise HTTPException(
- status_code=404,
- detail=f'提示词不存在: {name}',
- )
- return {
- 'status': 'ok',
- **detail,
- }
- except HTTPException:
- raise
- except Exception as e:
- logger.error('获取提示词详情失败: %s', e)
- raise HTTPException(status_code=500, detail=str(e))
- @router.get(
- '/api/prompts/{name}/versions',
- summary='获取版本列表',
- description='获取指定提示词的所有历史版本。',
- )
- async def list_prompt_versions(name: str):
- """
- 获取指定提示词的所有历史版本。
- """
- try:
- manager = _get_manager()
- # 先检查提示词是否存在
- if name not in PROMPT_FILE_MAP:
- detail = manager.get_prompt_detail(name)
- if detail is None:
- raise HTTPException(
- status_code=404,
- detail=f'提示词不存在: {name}',
- )
- # 获取当前激活版本
- current_info = manager.get_prompt_detail(name)
- current_version = current_info.get('version', '') if current_info else ''
- versions = manager.get_versions(name)
- chain = PROMPT_CHAIN_MAP.get(name, '')
- return {
- 'status': 'ok',
- 'name': name,
- 'chain': chain,
- 'current_version': current_version,
- 'versions': versions,
- }
- except HTTPException:
- raise
- except Exception as e:
- logger.error('获取版本列表失败: %s', e)
- raise HTTPException(status_code=500, detail=str(e))
- @router.post(
- '/api/prompts/save',
- summary='保存新版本',
- description='保存提示词的新版本。将当前编辑内容保存为新版本,并可选设为当前激活版本。',
- )
- async def save_prompt_version(body: dict):
- """
- 保存提示词的新版本。
- """
- try:
- name = body.get('name', '')
- system_prompt = body.get('system_prompt', '')
- user_prompt = body.get('user_prompt', '')
- note = body.get('note', '')
- set_current = body.get('set_current', True)
- # 参数验证
- if not name:
- raise HTTPException(status_code=422, detail='name 不能为空')
- if not system_prompt:
- raise HTTPException(status_code=422, detail='system_prompt 不能为空')
- if not user_prompt:
- raise HTTPException(status_code=422, detail='user_prompt 不能为空')
- manager = _get_manager()
- result = manager.save_new_version(
- name=name,
- system_prompt=system_prompt,
- user_prompt=user_prompt,
- note=note,
- set_current=set_current,
- )
- return {
- 'success': True,
- 'name': result['name'],
- 'version': result['version'],
- 'time': result['time'],
- 'message': f'已保存新版本 {result["version"]}',
- }
- except HTTPException:
- raise
- except ValueError as e:
- raise HTTPException(status_code=404, detail=str(e))
- except Exception as e:
- logger.error('保存新版本失败: %s', e)
- raise HTTPException(status_code=500, detail=str(e))
- @router.post(
- '/api/prompts/compare',
- summary='版本对比',
- description='对比两个版本的差异(行级 Diff)。',
- )
- async def compare_prompt_versions(body: dict):
- """
- 对比两个版本的差异。
- """
- try:
- name = body.get('name', '')
- base_version = body.get('base_version', '')
- target_version = body.get('target_version', '')
- if not name or not base_version or not target_version:
- raise HTTPException(
- status_code=422,
- detail='name, base_version, target_version 不能为空',
- )
- manager = _get_manager()
- result = manager.compare_versions(name, base_version, target_version)
- return {
- 'status': 'ok',
- **result,
- }
- except HTTPException:
- raise
- except FileNotFoundError as e:
- raise HTTPException(status_code=404, detail=str(e))
- except Exception as e:
- logger.error('版本对比失败: %s', e)
- raise HTTPException(status_code=500, detail=str(e))
- @router.post(
- '/api/prompts/activate',
- summary='激活版本',
- description='将指定版本设为当前激活版本(覆盖写入主 YAML 文件)。',
- )
- async def activate_prompt_version(body: dict):
- """
- 将指定版本设为当前激活版本。
- """
- try:
- name = body.get('name', '')
- version = body.get('version', '')
- if not name or not version:
- raise HTTPException(
- status_code=422,
- detail='name, version 不能为空',
- )
- manager = _get_manager()
- result = manager.activate_version(name, version)
- return {
- 'success': result['success'],
- 'name': result['name'],
- 'version': result['version'],
- 'message': f'已激活版本 {result["version"]}',
- }
- except HTTPException:
- raise
- except (ValueError, FileNotFoundError) as e:
- raise HTTPException(status_code=404, detail=str(e))
- except Exception as e:
- logger.error('激活版本失败: %s', e)
- raise HTTPException(status_code=500, detail=str(e))
- @router.post(
- '/api/prompts/rollback',
- summary='回滚版本',
- description='回滚到指定历史版本(等同于将该版本内容设为当前激活版本)。',
- )
- async def rollback_prompt_version(body: dict):
- """
- 回滚到指定历史版本。
- """
- try:
- name = body.get('name', '')
- version = body.get('version', '')
- if not name or not version:
- raise HTTPException(
- status_code=422,
- detail='name, version 不能为空',
- )
- manager = _get_manager()
- result = manager.rollback_version(name, version)
- return {
- 'success': result['success'],
- 'name': result['name'],
- 'version': result['version'],
- 'message': f'已回滚到版本 {result["version"]}',
- }
- except HTTPException:
- raise
- except (ValueError, FileNotFoundError) as e:
- raise HTTPException(status_code=404, detail=str(e))
- except Exception as e:
- logger.error('回滚版本失败: %s', e)
- raise HTTPException(status_code=500, detail=str(e))
|