Преглед изворни кода

优化停止按钮,修复调用重复回答

lxylxy123321 пре 1 дан
родитељ
комит
bf9c21044f
4 измењених фајлова са 60 додато и 24 уклоњено
  1. 2 2
      backend/app/api/deployment.py
  2. 32 13
      backend/app/core/inference_worker.py
  3. 25 8
      frontend/src/pages/Deployment.tsx
  4. 1 1
      test.py

+ 2 - 2
backend/app/api/deployment.py

@@ -176,7 +176,7 @@ async def proxy_chat_completions(task_id: str, request: Request):
         "temperature": body.get("temperature", 0.7),
         "top_p": body.get("top_p", 0.9),
         "do_sample": body.get("temperature", 0.7) > 0,
-        "repetition_penalty": body.get("repetition_penalty", 1.0),
+        "repetition_penalty": body.get("repetition_penalty", 1.1),
     }
 
     worker_resp = await deploy_service.proxy_to_worker(task_id, worker_req)
@@ -229,7 +229,7 @@ async def proxy_completions(task_id: str, request: Request):
         "temperature": body.get("temperature", 0.7),
         "top_p": body.get("top_p", 0.9),
         "do_sample": body.get("temperature", 0.7) > 0,
-        "repetition_penalty": body.get("repetition_penalty", 1.0),
+        "repetition_penalty": body.get("repetition_penalty", 1.1),
     }
 
     worker_resp = await deploy_service.proxy_to_worker(task_id, worker_req)

+ 32 - 13
backend/app/core/inference_worker.py

@@ -34,8 +34,17 @@ import threading
 import sys
 
 
-def _build_prompt_from_messages(messages: list[dict]) -> str:
-    """将 OpenAI 消息格式转为模型输入文本。"""
+def _build_prompt_from_messages(tokenizer, messages: list[dict]) -> str:
+    """将 OpenAI 消息格式转为模型输入文本。
+
+    优先使用 tokenizer 自带的 apply_chat_template(Qwen3.5 等模型内建了正确的模板),
+    只有当 tokenizer 没有 chat_template 时才回退到手动拼接。
+    """
+    if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
+        return tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+    # 回退:手动拼接(兼容没有 chat_template 的模型)
     parts = []
     for msg in messages:
         role = msg.get("role", "")
@@ -51,16 +60,21 @@ def _build_prompt_from_messages(messages: list[dict]) -> str:
 
 
 def _build_stop_criteria(tokenizer, model_device):
-    """构建 StoppingCriteria,遇到角色切换标记时停止生成,防止复读。"""
+    """构建 StoppingCriteria,遇到角色切换标记或 eos 时停止生成,防止复读。"""
     from transformers import StoppingCriteria, StoppingCriteriaList
 
-    # 当模型开始生成下一个 role 标记时就应该停止
-    stop_phrases = ["<|user|>", "<|system|>", "<|assistant|>"]
-    # 预编码 stop 短语,用于精确匹配
+    # 收集所有 stop 短语
+    stop_phrases = ["<|im_end|>", "<|endoftext|>", "<|eob|>", "<|eol|>", "<|user|>", "<|system|>", "<|assistant|>"]
+
     stop_token_ids = []
     for phrase in stop_phrases:
         ids = tokenizer.encode(phrase, add_special_tokens=False)
-        stop_token_ids.append(ids)
+        if ids:
+            stop_token_ids.append(ids)
+
+    # 也加入 eos_token_id(如果有)
+    if tokenizer.eos_token_id is not None:
+        stop_token_ids.append([tokenizer.eos_token_id])
 
     class StopOnRoleToken(StoppingCriteria):
         def __init__(self, stop_sequences, device):
@@ -104,7 +118,7 @@ class InferenceWorker:
         # 支持两种输入:messages(OpenAI 格式)或 prompt(原始文本)
         messages = request.get("messages")
         if messages:
-            prompt = _build_prompt_from_messages(messages)
+            prompt = _build_prompt_from_messages(self.tokenizer, messages)
         else:
             prompt = request.get("prompt", "")
 
@@ -136,11 +150,16 @@ class InferenceWorker:
         generated = self.tokenizer.decode(
             outputs[0][prompt_tokens:], skip_special_tokens=True
         )
-        # 清理可能残留的角色标记(防止 StoppingCriteria 触发前的部分 token)
-        for marker in ["<|user|>", "<|system|>", "<|assistant|>"]:
-            if marker in generated:
-                generated = generated[:generated.index(marker)]
-        generated = generated.strip()
+        # 文本级兜底截断:在生成文本中找到最早的 stop 标记并截断
+        # 防止 StoppingCriteria 因 tokenizer 编码差异未能触发
+        _stop_markers = ["<|eob|>", "<|im_end|>", "<|endoftext|>",
+                        "<|user|>", "<|system|>", "<|assistant|>"]
+        earliest = len(generated)
+        for marker in _stop_markers:
+            idx = generated.find(marker)
+            if idx != -1 and idx < earliest:
+                earliest = idx
+        generated = generated[:earliest].strip()
         completion_tokens = outputs.shape[1] - prompt_tokens
 
         return {

+ 25 - 8
frontend/src/pages/Deployment.tsx

@@ -32,6 +32,7 @@ export function Deployment() {
 
   const pollingRef = useRef<ReturnType<typeof setInterval> | null>(null)
   const servicesPollingRef = useRef<ReturnType<typeof setInterval> | null>(null)
+  const [loadingTaskIds, setLoadingTaskIds] = useState<Set<string>>(new Set())
 
   // 加载 API Keys
   const loadApiKeys = useCallback(() => {
@@ -129,15 +130,19 @@ export function Deployment() {
   }
 
   const handleStop = (taskId: string) => {
+    setLoadingTaskIds(prev => new Set(prev).add(taskId))
     api.deployment.stop(taskId)
       .then(() => loadServices())
       .catch(() => {})
+      .finally(() => setLoadingTaskIds(prev => { const next = new Set(prev); next.delete(taskId); return next }))
   }
 
   const handleRestart = (taskId: string) => {
+    setLoadingTaskIds(prev => new Set(prev).add(taskId))
     api.deployment.restart(taskId)
       .then(() => loadServices())
       .catch(() => {})
+      .finally(() => setLoadingTaskIds(prev => { const next = new Set(prev); next.delete(taskId); return next }))
   }
 
   const tabStyle = (active: boolean): React.CSSProperties => ({
@@ -468,7 +473,7 @@ export function Deployment() {
         ) : (
           <div style={{ display: 'flex', flexDirection: 'column', gap: 12 }}>
             {services.map(svc => (
-              <ServiceCard key={svc.task_id} service={svc} onStop={() => handleStop(svc.task_id)} onRestart={() => handleRestart(svc.task_id)} />
+              <ServiceCard key={svc.task_id} service={svc} onStop={() => handleStop(svc.task_id)} onRestart={() => handleRestart(svc.task_id)} loading={loadingTaskIds.has(svc.task_id)} />
             ))}
           </div>
         )}
@@ -523,7 +528,7 @@ function TaskStatus({ result }: { result: DeployResponse }) {
   )
 }
 
-function ServiceCard({ service, onStop, onRestart }: { service: DeployedServiceInfo; onStop: () => void; onRestart: () => void }) {
+function ServiceCard({ service, onStop, onRestart, loading }: { service: DeployedServiceInfo; onStop: () => void; onRestart: () => void; loading?: boolean }) {
   const [showUsage, setShowUsage] = useState(false)
   const isRunning = service.status === 'running'
   const isStopped = service.status === 'stopped'
@@ -578,26 +583,38 @@ function ServiceCard({ service, onStop, onRestart }: { service: DeployedServiceI
               </button>
               <button
                 onClick={onStop}
+                disabled={loading}
                 style={{
                   padding: '6px 12px', borderRadius: 6,
-                  border: '1px solid #fca5a5', background: '#fff', color: '#dc2626',
-                  cursor: 'pointer', fontSize: 12, fontWeight: 500,
+                  border: '1px solid #fca5a5',
+                  background: loading ? '#fee2e2' : '#fff',
+                  color: loading ? '#fca5a5' : '#dc2626',
+                  cursor: loading ? 'not-allowed' : 'pointer',
+                  fontSize: 12, fontWeight: 500,
+                  display: 'inline-flex', alignItems: 'center', gap: 4,
                 }}
               >
-                停止
+                {loading && <span style={{ animation: 'spin 1s linear infinite', display: 'inline-block' }}>⟳</span>}
+                {loading ? '停止中...' : '停止'}
               </button>
             </>
           )}
           {isStopped && (
             <button
               onClick={onRestart}
+              disabled={loading}
               style={{
                 padding: '6px 12px', borderRadius: 6,
-                border: '1px solid #86efac', background: '#fff', color: '#16a34a',
-                cursor: 'pointer', fontSize: 12, fontWeight: 500,
+                border: '1px solid #86efac',
+                background: loading ? '#dcfce7' : '#fff',
+                color: loading ? '#86efac' : '#16a34a',
+                cursor: loading ? 'not-allowed' : 'pointer',
+                fontSize: 12, fontWeight: 500,
+                display: 'inline-flex', alignItems: 'center', gap: 4,
               }}
             >
-              重启
+              {loading && <span style={{ animation: 'spin 1s linear infinite', display: 'inline-block' }}>⟳</span>}
+              {loading ? '重启中...' : '重启'}
             </button>
           )}
           {isPending && (

+ 1 - 1
test.py

@@ -7,7 +7,7 @@ client = OpenAI(
 
 response = client.chat.completions.create(
     model="local-model",
-    messages=[{"role": "user", "content": "你"}],
+    messages=[{"role": "user", "content": "你是谁,是哪个模型,详细说一下"}],
     max_tokens=512,
     temperature=0.7
 )