Przeglądaj źródła

向量库案例测试

lingmin_package@163.com 3 miesięcy temu
rodzic
commit
3859ab37b5

+ 7 - 0
README.md

@@ -25,6 +25,13 @@
   pip install aiomysql -i https://mirrors.aliyun.com/pypi/simple/
 
 
+
+### 向量模型和重排序模型测试
+  cd LQAgentPlatform
+  python foundation/models/silicon_flow.py
+
+
+
 ### 测试接口
 
   #### 生成模型接口 

+ 18 - 0
config/config.ini

@@ -86,3 +86,21 @@ MYSQL_DB=lq_db
 MYSQL_MIN_SIZE=1
 MYSQL_MAX_SIZE=2
 MYSQL_AUTO_COMMIT=True
+
+
+
+
+[pgvector]
+PGVECTOR_HOST=124.223.140.149
+PGVECTOR_PORT=7432
+PGVECTOR_DB=vector_db
+PGVECTOR_USER=vector_user
+PGVECTOR_PASSWORD=pg16@123
+
+
+[milvus]
+# NOTE(review): these values are copied verbatim from the [pgvector]
+# section above (same host/port/db/user/password); Milvus normally
+# listens on 19530 — confirm this section before use.
+MILVUS_HOST=124.223.140.149
+MILVUS_PORT=7432
+MILVUS_DB=vector_db
+MILVUS_USER=vector_user
+MILVUS_PASSWORD=pg16@123

+ 2 - 35
foundation/models/silicon_flow.py

@@ -12,7 +12,7 @@ from foundation.logger.loggering import server_logger
 from foundation.utils.common import handler_err
 from openai import OpenAI
 from langchain_core.embeddings import Embeddings
-from chromadb.utils.embedding_functions import EmbeddingFunction
+#from chromadb.utils.embedding_functions import EmbeddingFunction
 from typing import List
 import numpy as np
 
@@ -55,39 +55,6 @@ class SiliconFlowEmbeddings(Embeddings):
 
 
 
-class ChromaSiliconFlowEmbedding(EmbeddingFunction):
-    """
-        将SiliconFlowEmbeddings适配到ChromaDB的嵌入函数接口
-    """
-    def __init__(self, embeddings):
-        self.embeddings = embeddings
-
-    def __call__(self, input: List[str]) -> List[List[float]]:
-        raw_embeddings = self.embeddings.embed_documents(input)  # 关键添加
-        return self.normalized_embeddings(raw_embeddings)
-
-    def embed_documents(self, input: List[str]) -> List[List[float]]:
-        raw_embeddings = self.embeddings.embed_documents(input)  # 关键添加
-        return self.normalized_embeddings(raw_embeddings)
-
-    def embed_query(self, text: str) -> List[float]:
-        """对查询文本进行向量化"""
-        raw_embeddings = self.embeddings.embed_documents([text])[0]
-        return self.normalized_embeddings(raw_embeddings)
-
-    
-    def normalized_embeddings(self , raw_embeddings):
-        # L2归一化处理
-        normalized = []
-        for vector in raw_embeddings:
-            norm = np.linalg.norm(vector)
-            if norm > 0:
-                normalized.append(vector / norm)
-            else:
-                normalized.append(vector)
-        return normalized
-
-
 
 class SiliconFlowAPI(BaseApiPlatform):
     def __init__(self , trace_id=""):
@@ -104,7 +71,7 @@ class SiliconFlowAPI(BaseApiPlatform):
         self.client = self.get_openai_client(self.model_server_url, self.api_key)
         # 创建LangChain兼容的嵌入对象
         langchain_embeddings = SiliconFlowEmbeddings(base_url = self.model_server_url , api_key=self.api_key , embed_model_id=self.embed_model_id)
-        self.embed_model = ChromaSiliconFlowEmbedding(embeddings=langchain_embeddings)
+        #self.embed_model = ChromaSiliconFlowEmbedding(embeddings=langchain_embeddings)
 
 
 

+ 108 - 0
foundation/rag/vector/base_vector.py

@@ -0,0 +1,108 @@
+from foundation.logger.loggering import server_logger as logger
+import os
+import time
+from tqdm import tqdm
+from typing import List, Dict, Any
+from foundation.models.base_online_platform import BaseApiPlatform
+
+
class BaseVectorDB:
    """
    Abstract base class for vector-database backends.

    Concrete subclasses (e.g. pgvector or Milvus implementations) must
    implement the document CRUD and retrieval hooks below.  Embeddings are
    produced through the injected online-model platform.
    """

    def __init__(self, base_api_platform: "BaseApiPlatform"):
        # Platform used to turn text into embedding vectors.
        self.base_api_platform = base_api_platform

    def text_to_vector(self, text: str) -> List[float]:
        """Convert a single text into its embedding vector."""
        return self.base_api_platform.get_embeddings([text])[0]

    def document_standard(self, documents: List[Dict[str, Any]]):
        """Normalize raw documents into the {'content', 'metadata'} shape."""
        raise NotImplementedError

    def add_document(self, param: Dict[str, Any], document: Dict[str, Any]):
        """
        Insert a single document.

        param:    backend-specific options (e.g. table/collection name).
        document: {'content': str, 'metadata': dict}.
        Returns the id of the inserted document.
        """
        raise NotImplementedError

    def add_batch_documents(self, param: Dict[str, Any], documents: List[Dict[str, Any]]):
        """
        Insert a batch of documents.

        param:     backend-specific options (e.g. table/collection name).
        documents: list of {'content': str, 'metadata': dict}.
        Returns the ids of the inserted documents.
        """
        raise NotImplementedError

    def add_tqdm_batch_documents(self, param: Dict[str, Any],
                                 documents: List[Dict[str, Any]], batch_size=10):
        """
        Insert documents in batches with a tqdm progress bar.

        Delegates each batch to add_batch_documents() and reports an
        approximate throughput (documents per minute) as "TPM".
        """
        logger.info(f"Inserting {len(documents)} documents.")
        start_time = time.time()
        total_docs_inserted = 0

        # Ceiling division: number of batches including a short final one.
        total_batches = (len(documents) + batch_size - 1) // batch_size

        with tqdm(total=total_batches, desc="Inserting batches", unit="batch") as pbar:
            for i in range(0, len(documents), batch_size):
                batch = documents[i:i + batch_size]
                self.add_batch_documents(param, batch)

                total_docs_inserted += len(batch)
                elapsed_time = time.time() - start_time
                if elapsed_time > 0:
                    tpm = (total_docs_inserted / elapsed_time) * 60
                    pbar.set_postfix({"TPM": f"{tpm:.2f}"})

                pbar.update(1)

    def similarity_search(self, param: Dict[str, Any], query_text: str, min_score=0.5,
                          top_k=10, filters: Dict[str, Any] = None):
        """Search documents similar to ``query_text`` (distance-based)."""
        raise NotImplementedError

    def retriever(self, param: Dict[str, Any], query_text: str,
                  top_k: int = 5, filters: Dict[str, Any] = None):
        """
        Retrieve documents for a user query.

        Note: the original file defined ``retriever`` twice (a one-argument
        form and this one); only this later definition was ever effective,
        so the shadowed duplicate has been removed.
        """
        raise NotImplementedError

+ 347 - 0
foundation/rag/vector/milvus_vector.py

@@ -0,0 +1,347 @@
+import time
+from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
+from sentence_transformers import SentenceTransformer
+import numpy as np
+from typing import List, Dict, Any, Optional
+import json
+from foundation.base.config import config_handler
+from foundation.logger.loggering import server_logger as logger
+from foundation.rag.vector.base_vector import BaseVectorDB
+from foundation.models.base_online_platform import BaseApiPlatform
+
+class MilvusVectorManager(BaseVectorDB):
+    def __init__(self, base_api_platform :BaseApiPlatform):
+        """
+        初始化 Milvus 连接
+        """
+        self.base_api_platform = base_api_platform
+
+        self.host = config_handler.get('milvus', 'MILVUS_HOST', 'localhost')
+        self.port = int(config_handler.get('milvus', 'MILVUS_PORT', '19530'))
+        self.user = config_handler.get('milvus', 'MILVUS_USER')
+        self.password = config_handler.get('milvus', 'MILVUS_PASSWORD')
+        
+        # 初始化文本向量化模型
+        #self.model = SentenceTransformer('all-MiniLM-L6-v2')  # 可以替换为其他模型
+        
+        # 连接到 Milvus
+        self.connect()
+    
+    def connect(self):
+        """连接到 Milvus 服务器"""
+        try:
+            connections.connect(
+                alias="default",
+                host=self.host,
+                port=self.port,
+                user=self.user,
+                password=self.password
+            )
+            logger.info(f"Connected to Milvus at {self.host}:{self.port}")
+        except Exception as e:
+            logger.error(f"Failed to connect to Milvus: {e}")
+            raise
+    
+    def create_collection(self, collection_name: str, dimension: int = 768, 
+                         description: str = "Vector collection for text embeddings"):
+        """
+        创建向量集合
+        """
+        try:
+            # 检查集合是否已存在
+            if utility.has_collection(collection_name):
+                logger.info(f"Collection {collection_name} already exists")
+                return
+            
+            # 定义字段
+            fields = [
+                FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
+                FieldSchema(name="text_content", dtype=DataType.VARCHAR, max_length=65535),
+                FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dimension),
+                FieldSchema(name="metadata", dtype=DataType.JSON),
+                FieldSchema(name="created_at", dtype=DataType.INT64)
+            ]
+            
+            # 创建集合模式
+            schema = CollectionSchema(
+                fields=fields,
+                description=description
+            )
+            
+            # 创建集合
+            collection = Collection(
+                name=collection_name,
+                schema=schema
+            )
+            
+            # 创建索引
+            index_params = {
+                "index_type": "IVF_FLAT",
+                "metric_type": "COSINE",
+                "params": {"nlist": 100}
+            }
+            
+            collection.create_index(field_name="embedding", index_params=index_params)
+            logger.info(f"Collection {collection_name} created successfully!")
+            
+        except Exception as e:
+            logger.error(f"Error creating collection: {e}")
+            raise
+    
+    
+    
+
+    def add_document(self , param: Dict[str, Any] , document: Dict[str, Any]):
+        """
+        插入单个文本及其向量
+        """
+        try:
+            collection_name = param.get('collection_name')
+            text = document.get('content')
+            metadata = document.get('metadata')
+            collection = Collection(collection_name)
+            created_at = None
+            
+            # 转换文本为向量
+            embedding = self.text_to_vector(text)
+            
+            # 准备数据
+            data = [
+                [text],  # text_content
+                [embedding],  # embedding
+                [metadata or {}],  # metadata
+                [created_at or int(time.time())]  # created_at
+            ]
+            
+            # 插入数据
+            insert_result = collection.insert(data)
+            collection.flush()  # 确保数据被写入
+            
+            logger.info(f"Text inserted with ID: {insert_result.primary_keys[0]}")
+            return insert_result.primary_keys[0]
+            
+        except Exception as e:
+            logger.error(f"Error inserting text: {e}")
+            return None
+    
+
+
+    def add_batch_documents(self , param: Dict[str, Any] , documents: List[Dict[str, Any]]):
+        """
+        批量插入文本
+        texts: [{'text': '...', 'metadata': {...}}, ...]
+        """
+        try:
+            collection_name = param.get('collection_name')
+            collection = Collection(collection_name)
+            
+            text_contents = []
+            embeddings = []
+            metadatas = []
+            timestamps = []
+            
+            for item in documents:
+                text = item['content']
+                metadata = item.get('metadata', {})
+                
+                # 转换文本为向量
+                embedding = self.text_to_vector(text)
+                
+                text_contents.append(text)
+                embeddings.append(embedding)
+                metadatas.append(metadata)
+                timestamps.append(int(time.time()))
+            
+            # 准备批量数据
+            data = [text_contents, embeddings, metadatas, timestamps]
+            
+            # 批量插入
+            insert_result = collection.insert(data)
+            collection.flush()  # 确保数据被写入
+            
+            logger.info(f"Batch inserted {len(text_contents)} records, IDs: {insert_result.primary_keys}")
+            return insert_result.primary_keys
+            
+        except Exception as e:
+            logger.error(f"Error batch inserting: {e}")
+            return None
+    
+
+
+
+    def similarity_search(self, param: Dict[str, Any], query_text: str , min_score=0.5 ,
+                           top_k=5, filters: Dict[str, Any] = None):
+        """
+        搜索相似文本
+        """
+        try:
+            collection_name = param.get('collection_name')
+            collection = Collection(collection_name)
+            
+            # 加载集合到内存(如果还没有加载)
+            collection.load()
+            
+            # 转换查询文本为向量
+            query_embedding = self.text_to_vector(query_text)
+            
+            # 搜索参数
+            search_params = {
+                "metric_type": "COSINE",
+                "params": {"nprobe": 10}
+            }
+             # 构建过滤表达式
+            filter_expr = self._create_filter(filters)
+            
+            # 执行搜索
+            results = collection.search(
+                data=[query_embedding],
+                anns_field="embedding",
+                param=search_params,
+                limit=top_k,
+                expr=filter_expr,
+                output_fields=["text_content", "metadata"]
+            )
+            
+            # 格式化结果
+            formatted_results = []
+            for hits in results:
+                for hit in hits:
+                    formatted_results.append({
+                        'id': hit.id,
+                        'text_content': hit.entity.get('text_content'),
+                        'metadata': hit.entity.get('metadata'),
+                        'distance': hit.distance,
+                        'similarity': 1 - hit.distance  # 转换为相似度
+                    })
+            
+            return formatted_results
+            
+        except Exception as e:
+            logger.error(f"Error searching: {e}")
+            return []
+    
+    def retriever(self, param: Dict[str, Any], query_text: str, 
+                          top_k: int = 5, filters: Dict[str, Any] = None):
+        """
+        带过滤条件的相似搜索
+        """
+        try:
+            collection_name = param.get('collection_name')
+            collection = Collection(collection_name)
+            collection.load()
+            
+            query_embedding = self.text_to_vector(query_text)
+            
+            # 构建过滤表达式
+            filter_expr = self._create_filter(filters)
+            
+            search_params = {
+                "metric_type": "COSINE",
+                "params": {"nprobe": 10}
+            }
+            
+            results = collection.search(
+                data=[query_embedding],
+                anns_field="embedding",
+                param=search_params,
+                limit=top_k,
+                expr=filter_expr,
+                output_fields=["text_content", "metadata"]
+            )
+            
+            formatted_results = []
+            for hits in results:
+                for hit in hits:
+                    formatted_results.append({
+                        'id': hit.id,
+                        'text_content': hit.entity.get('text_content'),
+                        'metadata': hit.entity.get('metadata'),
+                        'distance': hit.distance,
+                        'similarity': 1 - hit.distance
+                    })
+            
+            return formatted_results
+            
+        except Exception as e:
+            logger.error(f"Error searching with filter: {e}")
+            return []
+    
+    
+    def _create_filter(self, filters: Dict[str, Any]) -> str:
+        """
+        创建过滤条件
+        """
+        # 构建过滤表达式
+        filter_expr = ""
+        if filters:
+            conditions = []
+            for key, value in filters.items():
+                if isinstance(value, str):
+                    conditions.append(f'metadata["{key}"] == "{value}"')
+                elif isinstance(value, (int, float)):
+                    conditions.append(f'metadata["{key}"] == {value}')
+                else:
+                    conditions.append(f'metadata["{key}"] == "{json.dumps(value)}"')
+            filter_expr = " and ".join(conditions)
+        
+        return filter_expr
+
+    def db_test(self):
+        import time
+        # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
+        from foundation.models.silicon_flow import SiliconFlowAPI
+        client = SiliconFlowAPI()
+        # 初始化 Milvus 管理器
+        milvus_manager = MilvusVectorManager(base_api_platform=client)
+        
+        # 创建集合
+        collection_name = 'text_embeddings'
+        milvus_manager.create_collection(collection_name, dimension=384)
+        
+        # 插入单个文本
+        sample_text = "这是一个关于人工智能的文档。"
+        milvus_manager.insert_text(
+            collection_name, 
+            sample_text, 
+            metadata={'category': 'AI', 'source': 'example'}
+        )
+        
+        # 批量插入文本
+        sample_texts = [
+            {
+                'text': '机器学习是人工智能的一个重要分支。',
+                'metadata': {'category': 'ML', 'author': 'John'}
+            },
+            {
+                'text': '深度学习在图像识别领域取得了显著成果。',
+                'metadata': {'category': 'Deep Learning', 'author': 'Jane'}
+            },
+            {
+                'text': '自然语言处理技术在聊天机器人中得到广泛应用。',
+                'metadata': {'category': 'NLP', 'author': 'Bob'}
+            }
+        ]
+        
+        param = {"collection_name": collection_name}
+        milvus_manager.add_batch_documents(param, sample_texts)
+        
+        # 搜索相似文本
+        query = "人工智能相关的技术"
+        similar_docs = milvus_manager.similarity_search(param, query, top_k=3)
+        
+        logger.info("Similar documents found:")
+        for doc in similar_docs:
+            logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['similarity']:.3f}")
+        
+        # 带过滤条件的搜索
+        filtered_docs = milvus_manager.search_with_filter(
+            collection_name, 
+            query, 
+            top_k=3, 
+            filters={'category': 'AI'}
+        )
+        
+        logger.info("\nFiltered similar documents:")
+        for doc in filtered_docs:
+            logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['similarity']:.3f}")
+

+ 269 - 0
foundation/rag/vector/pg_vector.py

@@ -0,0 +1,269 @@
+
+import psycopg2
+from psycopg2.extras import RealDictCursor
+import numpy as np
+from sentence_transformers import SentenceTransformer
+import json
+from typing import List, Dict, Any
+from foundation.base.config import config_handler
+from foundation.logger.loggering import server_logger as logger
+from foundation.rag.vector.base_vector import BaseVectorDB
+from foundation.models.base_online_platform import BaseApiPlatform
+
+class PGVectorDB(BaseVectorDB):
+    def __init__(self , base_api_platform :BaseApiPlatform):
+        """
+        初始化 pgvector 连接
+        """
+        self.connection_params = {
+            'host': config_handler.get('pgvector', 'PGVECTOR_HOST', 'localhost'),
+            'port': int(config_handler.get('pgvector', 'PGVECTOR_PORT', '5432')),
+            'database': config_handler.get('pgvector', 'PGVECTOR_DB', 'postgres'),
+            'user': config_handler.get('pgvector', 'PGVECTOR_USER', 'postgres'),
+            'password': config_handler.get('pgvector', 'PGVECTOR_PASSWORD', 'postgres')
+        }
+
+        self.base_api_platform = base_api_platform
+        
+        
+    def get_connection(self):
+        """获取数据库连接"""
+        #logger.info(f"Connecting to PostgreSQL...{self.connection_params}")
+        conn = psycopg2.connect(**self.connection_params)
+        # 启用 pgvector 扩展
+        with conn.cursor() as cur:
+            cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
+        conn.commit()
+        return conn
+    
+    def create_table(self, table_name: str, vector_dim: int = 384):
+        """
+        创建向量表
+        """
+        conn = self.get_connection()
+        try:
+            with conn.cursor() as cur:
+                # 创建表
+                create_table_sql = f"""
+                CREATE TABLE IF NOT EXISTS {table_name} (
+                    id SERIAL PRIMARY KEY,
+                    text_content TEXT,
+                    embedding vector({vector_dim}),
+                    metadata JSONB DEFAULT '{{}}'::jsonb,
+                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+                );
+                
+                -- 创建向量相似度索引
+                CREATE INDEX IF NOT EXISTS idx_{table_name}_embedding 
+                ON {table_name} USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);
+                """
+                cur.execute(create_table_sql)
+                conn.commit()
+                print(f"Table {table_name} created successfully!")
+        except Exception as e:
+            logger.error(f"Error creating table: {e}")
+            conn.rollback()
+        finally:
+            conn.close()
+    
+
+    def document_standard(self, documents: List[Dict[str, Any]]):
+        """
+        对文档进行结果标准处理
+        """
+        result = []
+        for doc in documents:
+            tmp = {}
+            tmp['content'] = doc.page_content
+            tmp['metadata'] = doc.metadata if doc.metadata else {}
+            result.append(tmp)
+        return result
+
+
+
+    def add_document(self , param: Dict[str, Any] , document: Dict[str, Any]):
+        """
+        插入单个文本及其向量
+        """
+        table_name = param.get('table_name')
+        text = document.get('content')
+        metadata = document.get('metadata')
+
+        conn = self.get_connection()
+        try:
+            with conn.cursor() as cur:
+                embedding = self.text_to_vector(text)
+                metadata = metadata or {}
+                
+                insert_sql = f"""
+                INSERT INTO {table_name} (text_content, embedding, metadata)
+                VALUES (%s, %s, %s)
+                RETURNING id;
+                """
+                cur.execute(insert_sql, (text, embedding, json.dumps(metadata)))
+                inserted_id = cur.fetchone()[0]
+                conn.commit()
+                print(f"Text inserted with ID: {inserted_id}")
+                return inserted_id
+        except Exception as e:
+            print(f"Error inserting text: {e}")
+            conn.rollback()
+            return None
+        finally:
+            conn.close()
+    
+    def add_batch_documents(self , param: Dict[str, Any] , documents: List[Dict[str, Any]]):
+        """
+        批量插入文本
+        texts: [{'text': '...', 'metadata': {...}}, ...]
+        """
+        table_name = param.get('table_name')
+        conn = self.get_connection()
+        try:
+            with conn.cursor() as cur:
+                # 准备数据
+                data_to_insert = []
+                for item in documents:
+                    text = item['content']
+                    metadata = item.get('metadata', {})
+                    embedding = self.text_to_vector(text)
+                    data_to_insert.append((text, embedding, json.dumps(metadata)))
+                
+                # 批量插入
+                insert_sql = f"""
+                INSERT INTO {table_name} (text_content, embedding, metadata)
+                VALUES (%s, %s, %s)
+                """
+                cur.executemany(insert_sql, data_to_insert)
+                conn.commit()
+                logger.info(f"Batch inserted {len(data_to_insert)} records")
+        except Exception as e:
+            logger.error(f"Error batch inserting: {e}")
+            conn.rollback()
+        finally:
+            conn.close()
+    
+    def similarity_search(self, param: Dict[str, Any], query_text: str , min_score=0.5 , 
+                          top_k=5, filters: Dict[str, Any] = None):
+        """
+        搜索相似文本
+            search_similar 使用距离度量(越小越相似)
+            
+        """
+        table_name = param.get('table_name')
+        conn = self.get_connection()
+        try:
+            with conn.cursor(cursor_factory=RealDictCursor) as cur:
+                query_embedding = self.text_to_vector(query_text)
+                
+                search_sql = f"""
+                SELECT id, text_content, metadata, 
+                       embedding <=> %s::vector AS distance
+                FROM {table_name}
+                ORDER BY embedding <=> %s::vector
+                LIMIT %s;
+                """
+                cur.execute(search_sql, (query_embedding, query_embedding, top_k))
+                results = cur.fetchall()
+                
+                return results
+        except Exception as e:
+            logger.error(f"Error searching: {e}")
+            return []
+        finally:
+            conn.close()
+
+    
+    def retriever(self, param: Dict[str, Any], query_text: str , min_score=0.1 , 
+                                 top_k=10, filters: Dict[str, Any] = None):
+        """
+        使用余弦相似度搜索相似文本
+        """
+        table_name = param.get('table_name')
+        conn = self.get_connection()
+        try:
+            with conn.cursor(cursor_factory=RealDictCursor) as cur:
+                query_embedding = self.text_to_vector(query_text)
+                
+                search_sql = f"""
+                SELECT id, text_content, metadata,
+                       1 - (embedding <=> %s::vector) AS cosine_similarity
+                FROM {table_name}
+                WHERE 1 - (embedding <=> %s::vector) > %s
+                ORDER BY 1 - (embedding <=> %s::vector) DESC
+                LIMIT %s;
+                """
+                cur.execute(search_sql, (query_embedding, query_embedding, min_score, query_embedding, top_k))
+                results = cur.fetchall()
+                # 打印结果
+                self.result_logger_info(query_text , results)
+                return results
+        except Exception as e:
+            logger.error(f"Error searching with cosine similarity: {e}")
+            return []
+        finally:
+            conn.close()
+
+    
+    def result_logger_info(self , query, result_docs_cos):
+        """
+            记录搜索结果
+        """
+        logger.info(f"\n {'=' * 50}")
+        # 使用余弦相似度搜索
+        logger.info(f"\nSimilar documents with cosine similarity,query:{query},result_count: {len(result_docs_cos)}:")
+        for doc in result_docs_cos:
+            logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['cosine_similarity']:.3f}")
+
+
+
+    def db_test(self , query_text: str):
+        """
+        测试数据库连接和操作
+        """
+        table_name = 'test_documents'
+        # 创建表
+        self.create_table(table_name, vector_dim=768)
+        
+        # 插入单个文本
+        sample_text = "这是一个关于人工智能的文档。"
+        #self.insert_text(table_name, sample_text, {'category': 'AI', 'source': 'example'})
+        
+        # 批量插入文本
+        sample_texts = [
+            {
+                'text': '机器学习是人工智能的一个重要分支。',
+                'metadata': {'category': 'ML', 'author': 'John'}
+            },
+            {
+                'text': '深度学习在图像识别领域取得了显著成果。',
+                'metadata': {'category': 'Deep Learning', 'author': 'Jane'}
+            },
+            {
+                'text': '自然语言处理技术在聊天机器人中得到广泛应用。',
+                'metadata': {'category': 'NLP', 'author': 'Bob'}
+            }
+        ]
+        
+        #self.batch_insert_texts(table_name, sample_texts)
+        
+
+        logger.info(f"\n {'=' * 50}")
+        # 搜索相似文本
+        #query = "人工智能相关的技术"
+        query = query_text
+        logger.info(f"\n query={query}")
+
+        similar_docs = self.search_similar(table_name, query, top_k=3)
+        logger.info(f"Similar documents found {len(similar_docs)}:")
+        for doc in similar_docs:
+            logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {1 - doc['distance']:.3f}")
+        
+        logger.info(f"\n {'=' * 50}")
+        # 使用余弦相似度搜索
+        similar_docs_cos = self.search_by_cosine_similarity(table_name, query, top_k=3)
+        
+        logger.info(f"\nSimilar documents with cosine similarity {len(similar_docs_cos)}:")
+        for doc in similar_docs_cos:
+            logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['cosine_similarity']:.3f}")
+

+ 5 - 4
views/__init__.py

@@ -21,10 +21,11 @@ async def lifespan(app: FastAPI):
     # 启动时加载工具
     #await mcp_server.get_mcp_tools()
     # 全局数据库连接池实例
-    async_db_pool = AsyncMySQLPool()
-    await async_db_pool.initialize()
-    app.state.async_db_pool = async_db_pool
-    server_logger.info(f"✅ MySQL数据库连接池:{app.state.async_db_pool}")
+    async_db_pool = None
+    # async_db_pool = AsyncMySQLPool()
+    # await async_db_pool.initialize()
+    # app.state.async_db_pool = async_db_pool
+    #server_logger.info(f"✅ MySQL数据库连接池:{app.state.async_db_pool}")
 
     yield
     # 关闭时清理

+ 133 - 0
views/test_views.py

@@ -26,6 +26,8 @@ from foundation.agent.workflow.test_workflow_graph import test_workflow_graph
 from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
 from database.repositories.bus_data_query import BasisOfPreparationDAO
 from foundation.utils.tool_utils import DateTimeEncoder
+from foundation.models.silicon_flow import SiliconFlowAPI
+from foundation.rag.vector.pg_vector import PGVectorDB
 
 
 @test_router.post("/generate/chat", response_model=TestForm)
@@ -594,3 +596,134 @@ async def test_mysql_add(
     except Exception as err:
         handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/list")
         return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
+
+
+
+
+##################【RAG 相关测试】##############################################
@test_router.post("/embedding", response_model=TestForm)
async def embedding_test_endpoint(
        param: TestForm,
        trace_id: str = Depends(get_operation_id)):
    """
    Embedding-model smoke test: embed ``param.input`` and return the
    vector together with its dimensionality.
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        text = param.input

        # SiliconFlowAPI reads its credentials from configuration
        # (SILICONFLOW_API_KEY must be set beforehand).
        base_api_platform = SiliconFlowAPI()
        embedding = base_api_platform.get_embeddings([text])[0]
        embed_dim = len(embedding)
        server_logger.info(trace_id=trace_id, msg=f"【result】: {embed_dim}")

        output = f"embed_dim={embed_dim},embedding:{embedding}"
        return JSONResponse(
            return_json(data={"output": output}, data_type="text", trace_id=trace_id))

    except ValueError as err:
        handler_err(server_logger, trace_id=trace_id, err=err, err_name="embedding")
        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))

    except Exception as err:
        handler_err(server_logger, trace_id=trace_id, err=err, err_name="embedding")
        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
+
+
+
+
+
@test_router.post("/bfp/search", response_model=TestForm)
async def bfp_search_endpoint(
        param: TestForm,
        trace_id: str = Depends(get_operation_id)):
    """
    Vector retrieval over the "basis of preparation" table.

    ``param.input`` carries the query text; ``param.config.session_id`` is
    (ab)used by this test endpoint to transport top_k as a string.
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        input_query = param.input
        top_k = int(param.config.session_id)

        # SILICONFLOW_API_KEY must be configured beforehand.
        client = SiliconFlowAPI()
        pg_vector_db = PGVectorDB(base_api_platform=client)
        output = pg_vector_db.retriever(
            param={"table_name": "tv_basis_of_preparation"},
            query_text=input_query, top_k=top_k)

        return JSONResponse(
            return_json(data={"output": output}, data_type="text", trace_id=trace_id))

    except ValueError as err:
        handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))

    except Exception as err:
        handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
+    
+
+
+
@test_router.post("/bfp/search/rerank", response_model=TestForm)
async def bfp_search_rerank_endpoint(
        param: TestForm,
        trace_id: str = Depends(get_operation_id)):
    """
    Document retrieval plus rerank over the "basis of preparation" table.

    Renamed from ``bfp_search_endpoint``: the original redefined the same
    function name twice in this module (shadowing the /bfp/search handler's
    name); the registered route path is unchanged.
    """
    try:
        server_logger.info(trace_id=trace_id, msg=f"{param}")
        input_query = param.input
        top_k = int(param.config.session_id)

        # SILICONFLOW_API_KEY must be configured beforehand.
        client = SiliconFlowAPI()
        pg_vector_db = PGVectorDB(base_api_platform=client)
        output = pg_vector_db.retriever(
            param={"table_name": "tv_basis_of_preparation"},
            query_text=input_query, top_k=top_k)

        # Rerank the retrieved texts against the query.
        content_list = [doc["text_content"] for doc in output]
        output = client.rerank(input_query=input_query, documents=content_list, top_n=top_k)

        return JSONResponse(
            return_json(data={"output": output}, data_type="text", trace_id=trace_id))

    except ValueError as err:
        handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search/rerank")
        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))

    except Exception as err:
        handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search/rerank")
        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))