import json
import re
from typing import Any, Dict, List

import psycopg2
from psycopg2.extras import RealDictCursor
import numpy as np

from foundation.base.config import config_handler
from foundation.logger.loggering import server_logger as logger
from foundation.rag.vector.base_vector import BaseVectorDB
from foundation.models.base_online_platform import BaseApiPlatform

# Table names cannot be bound as SQL parameters, so they are interpolated into
# query strings. Restrict them to plain identifiers to prevent SQL injection.
_IDENTIFIER_RE = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$')


class PGVectorDB(BaseVectorDB):
    """Vector store backed by PostgreSQL with the pgvector extension.

    Embeddings are produced via ``text_to_vector`` (inherited from
    ``BaseVectorDB`` — not defined here; TODO confirm it returns a value
    psycopg2 can adapt for the ``vector`` column type).
    """

    def __init__(self, base_api_platform: BaseApiPlatform):
        """Read connection settings from config and keep the API platform handle.

        Args:
            base_api_platform: platform used by the (inherited) embedding call.
        """
        self.connection_params = {
            'host': config_handler.get('pgvector', 'PGVECTOR_HOST', 'localhost'),
            'port': int(config_handler.get('pgvector', 'PGVECTOR_PORT', '5432')),
            'database': config_handler.get('pgvector', 'PGVECTOR_DB', 'postgres'),
            'user': config_handler.get('pgvector', 'PGVECTOR_USER', 'postgres'),
            'password': config_handler.get('pgvector', 'PGVECTOR_PASSWORD', 'postgres'),
        }
        self.base_api_platform = base_api_platform

    @staticmethod
    def _validate_identifier(name: str) -> str:
        """Return *name* if it is a safe SQL identifier, else raise ValueError.

        Guards the f-string interpolation of table names into SQL text
        (identifiers cannot be passed as bound parameters).
        """
        if not name or not _IDENTIFIER_RE.match(name):
            raise ValueError(f"Invalid table name: {name!r}")
        return name

    def get_connection(self):
        """Open a new connection and ensure the pgvector extension exists."""
        conn = psycopg2.connect(**self.connection_params)
        # Idempotent: CREATE EXTENSION IF NOT EXISTS is a no-op when present.
        with conn.cursor() as cur:
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
        conn.commit()
        return conn

    def create_table(self, table_name: str, vector_dim: int = 384):
        """Create the vector table and an ivfflat cosine index if missing.

        Args:
            table_name: target table (must be a plain SQL identifier).
            vector_dim: dimensionality of the ``embedding`` column.
        """
        table_name = self._validate_identifier(table_name)
        conn = self.get_connection()
        try:
            with conn.cursor() as cur:
                create_table_sql = f"""
                    CREATE TABLE IF NOT EXISTS {table_name} (
                        id SERIAL PRIMARY KEY,
                        text_content TEXT,
                        embedding vector({vector_dim}),
                        metadata JSONB DEFAULT '{{}}'::jsonb,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    );
                    -- vector-similarity index (cosine distance)
                    CREATE INDEX IF NOT EXISTS idx_{table_name}_embedding
                    ON {table_name} USING ivfflat (embedding vector_cosine_ops)
                    WITH (lists = 100);
                """
                cur.execute(create_table_sql)
                conn.commit()
                logger.info(f"Table {table_name} created successfully!")
        except Exception as e:
            logger.error(f"Error creating table: {e}")
            conn.rollback()
        finally:
            conn.close()

    def document_standard(self, documents: List[Dict[str, Any]]):
        """Normalize loader documents to ``{'content', 'metadata'}`` dicts.

        Args:
            documents: objects exposing ``page_content`` and ``metadata``
                attributes (langchain-style documents — TODO confirm).
        """
        return [
            {
                'content': doc.page_content,
                'metadata': doc.metadata if doc.metadata else {},
            }
            for doc in documents
        ]

    def add_document(self, param: Dict[str, Any], document: Dict[str, Any]):
        """Insert a single text with its embedding.

        Args:
            param: must contain ``table_name``.
            document: ``{'content': str, 'metadata': dict | None}``.

        Returns:
            The inserted row id, or ``None`` on failure.
        """
        table_name = self._validate_identifier(param.get('table_name'))
        text = document.get('content')
        metadata = document.get('metadata') or {}
        conn = self.get_connection()
        try:
            with conn.cursor() as cur:
                embedding = self.text_to_vector(text)
                insert_sql = f"""
                    INSERT INTO {table_name} (text_content, embedding, metadata)
                    VALUES (%s, %s, %s)
                    RETURNING id;
                """
                cur.execute(insert_sql, (text, embedding, json.dumps(metadata)))
                inserted_id = cur.fetchone()[0]
                conn.commit()
                logger.info(f"Text inserted with ID: {inserted_id}")
                return inserted_id
        except Exception as e:
            logger.error(f"Error inserting text: {e}")
            conn.rollback()
            return None
        finally:
            conn.close()

    def add_batch_documents(self, param: Dict[str, Any],
                            documents: List[Dict[str, Any]]):
        """Insert many documents in one batch.

        Args:
            param: must contain ``table_name``.
            documents: list of ``{'content': str, 'metadata': dict}`` items.
        """
        table_name = self._validate_identifier(param.get('table_name'))
        conn = self.get_connection()
        try:
            with conn.cursor() as cur:
                # Embed each document up front so a mid-batch failure rolls
                # back cleanly without partial inserts.
                data_to_insert = [
                    (
                        item['content'],
                        self.text_to_vector(item['content']),
                        json.dumps(item.get('metadata', {})),
                    )
                    for item in documents
                ]
                insert_sql = f"""
                    INSERT INTO {table_name} (text_content, embedding, metadata)
                    VALUES (%s, %s, %s)
                """
                cur.executemany(insert_sql, data_to_insert)
                conn.commit()
                logger.info(f"Batch inserted {len(data_to_insert)} records")
        except Exception as e:
            logger.error(f"Error batch inserting: {e}")
            conn.rollback()
        finally:
            conn.close()

    def similarity_search(self, param: Dict[str, Any], query_text: str,
                          min_score=0.5, top_k=5,
                          filters: Dict[str, Any] = None):
        """Nearest-neighbor search by cosine *distance* (smaller = more similar).

        Args:
            param: must contain ``table_name``.
            query_text: text to embed and search with.
            min_score: currently unused here (distance is not thresholded).
            top_k: number of rows to return.
            filters: currently unused; reserved for metadata filtering.

        Returns:
            List of rows (``id``, ``text_content``, ``metadata``, ``distance``),
            or ``[]`` on error.
        """
        table_name = self._validate_identifier(param.get('table_name'))
        conn = self.get_connection()
        try:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                query_embedding = self.text_to_vector(query_text)
                search_sql = f"""
                    SELECT id, text_content, metadata,
                           embedding <=> %s::vector AS distance
                    FROM {table_name}
                    ORDER BY embedding <=> %s::vector
                    LIMIT %s;
                """
                cur.execute(search_sql, (query_embedding, query_embedding, top_k))
                return cur.fetchall()
        except Exception as e:
            logger.error(f"Error searching: {e}")
            return []
        finally:
            conn.close()

    def retriever(self, param: Dict[str, Any], query_text: str,
                  min_score=0.1, top_k=10, filters: Dict[str, Any] = None):
        """Search by cosine *similarity* (``1 - distance``), with a threshold.

        Args:
            param: must contain ``table_name``.
            query_text: text to embed and search with.
            min_score: minimum cosine similarity to include a row.
            top_k: maximum number of rows to return.
            filters: currently unused; reserved for metadata filtering.

        Returns:
            List of rows including ``cosine_similarity``, or ``[]`` on error.
        """
        table_name = self._validate_identifier(param.get('table_name'))
        conn = self.get_connection()
        try:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                query_embedding = self.text_to_vector(query_text)
                search_sql = f"""
                    SELECT id, text_content, metadata,
                           1 - (embedding <=> %s::vector) AS cosine_similarity
                    FROM {table_name}
                    WHERE 1 - (embedding <=> %s::vector) > %s
                    ORDER BY 1 - (embedding <=> %s::vector) DESC
                    LIMIT %s;
                """
                cur.execute(
                    search_sql,
                    (query_embedding, query_embedding, min_score,
                     query_embedding, top_k),
                )
                results = cur.fetchall()
                self.result_logger_info(query_text, results)
                return results
        except Exception as e:
            logger.error(f"Error searching with cosine similarity: {e}")
            return []
        finally:
            conn.close()

    def result_logger_info(self, query, result_docs_cos):
        """Log a summary of cosine-similarity search results."""
        logger.info(f"\n {'=' * 50}")
        logger.info(
            f"\nSimilar documents with cosine similarity,query:{query},"
            f"result_count: {len(result_docs_cos)}:"
        )
        for doc in result_docs_cos:
            logger.info(
                f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., "
                f"Similarity: {doc['cosine_similarity']:.3f}"
            )

    def db_test(self, query_text: str):
        """Smoke-test table creation and both search paths.

        Fixed: previously called non-existent ``search_similar`` /
        ``search_by_cosine_similarity``; the actual methods are
        ``similarity_search`` and ``retriever`` and take a param dict.
        """
        table_name = 'test_documents'
        param = {'table_name': table_name}

        self.create_table(table_name, vector_dim=768)

        # Example single-document insert (left disabled for the smoke test):
        # self.add_document(param, {'content': '这是一个关于人工智能的文档。',
        #                           'metadata': {'category': 'AI', 'source': 'example'}})

        # Example batch insert; keys match add_batch_documents ('content').
        sample_documents = [
            {
                'content': '机器学习是人工智能的一个重要分支。',
                'metadata': {'category': 'ML', 'author': 'John'},
            },
            {
                'content': '深度学习在图像识别领域取得了显著成果。',
                'metadata': {'category': 'Deep Learning', 'author': 'Jane'},
            },
            {
                'content': '自然语言处理技术在聊天机器人中得到广泛应用。',
                'metadata': {'category': 'NLP', 'author': 'Bob'},
            },
        ]
        # self.add_batch_documents(param, sample_documents)
        _ = sample_documents  # kept as reference fixture for manual runs

        logger.info(f"\n {'=' * 50}")
        query = query_text
        logger.info(f"\n query={query}")

        # Distance-ordered search: similarity is 1 - distance.
        similar_docs = self.similarity_search(param, query, top_k=3)
        logger.info(f"Similar documents found {len(similar_docs)}:")
        for doc in similar_docs:
            logger.info(
                f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., "
                f"Similarity: {1 - doc['distance']:.3f}"
            )

        logger.info(f"\n {'=' * 50}")

        # Threshold-filtered cosine-similarity search.
        similar_docs_cos = self.retriever(param, query, top_k=3)
        logger.info(
            f"\nSimilar documents with cosine similarity {len(similar_docs_cos)}:"
        )
        for doc in similar_docs_cos:
            logger.info(
                f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., "
                f"Similarity: {doc['cosine_similarity']:.3f}"
            )