# pg_vector.py — PGVectorDB: a pgvector-backed (PostgreSQL) vector store implementation.
  1. import psycopg2
  2. from psycopg2.extras import RealDictCursor
  3. import numpy as np
  4. #from sentence_transformers import SentenceTransformer
  5. import json
  6. from typing import List, Dict, Any
  7. from foundation.base.config import config_handler
  8. from foundation.logger.loggering import server_logger as logger
  9. from foundation.rag.vector.base_vector import BaseVectorDB
  10. from foundation.models.base_online_platform import BaseApiPlatform
  11. class PGVectorDB(BaseVectorDB):
  12. def __init__(self , base_api_platform :BaseApiPlatform):
  13. """
  14. 初始化 pgvector 连接
  15. """
  16. self.connection_params = {
  17. 'host': config_handler.get('pgvector', 'PGVECTOR_HOST', 'localhost'),
  18. 'port': int(config_handler.get('pgvector', 'PGVECTOR_PORT', '5432')),
  19. 'database': config_handler.get('pgvector', 'PGVECTOR_DB', 'postgres'),
  20. 'user': config_handler.get('pgvector', 'PGVECTOR_USER', 'postgres'),
  21. 'password': config_handler.get('pgvector', 'PGVECTOR_PASSWORD', 'postgres')
  22. }
  23. self.base_api_platform = base_api_platform
  24. def get_connection(self):
  25. """获取数据库连接"""
  26. #logger.info(f"Connecting to PostgreSQL...{self.connection_params}")
  27. conn = psycopg2.connect(**self.connection_params)
  28. # 启用 pgvector 扩展
  29. with conn.cursor() as cur:
  30. cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
  31. conn.commit()
  32. return conn
  33. def create_table(self, table_name: str, vector_dim: int = 384):
  34. """
  35. 创建向量表
  36. """
  37. conn = self.get_connection()
  38. try:
  39. with conn.cursor() as cur:
  40. # 创建表
  41. create_table_sql = f"""
  42. CREATE TABLE IF NOT EXISTS {table_name} (
  43. id SERIAL PRIMARY KEY,
  44. text_content TEXT,
  45. embedding vector({vector_dim}),
  46. metadata JSONB DEFAULT '{{}}'::jsonb,
  47. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  48. );
  49. -- 创建向量相似度索引
  50. CREATE INDEX IF NOT EXISTS idx_{table_name}_embedding
  51. ON {table_name} USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);
  52. """
  53. cur.execute(create_table_sql)
  54. conn.commit()
  55. print(f"Table {table_name} created successfully!")
  56. except Exception as e:
  57. logger.error(f"Error creating table: {e}")
  58. conn.rollback()
  59. finally:
  60. conn.close()
  61. def document_standard(self, documents: List[Dict[str, Any]]):
  62. """
  63. 对文档进行结果标准处理
  64. """
  65. result = []
  66. for doc in documents:
  67. tmp = {}
  68. tmp['content'] = doc.page_content
  69. tmp['metadata'] = doc.metadata if doc.metadata else {}
  70. result.append(tmp)
  71. return result
  72. def add_document(self , param: Dict[str, Any] , document: Dict[str, Any]):
  73. """
  74. 插入单个文本及其向量
  75. """
  76. table_name = param.get('table_name')
  77. text = document.get('content')
  78. metadata = document.get('metadata')
  79. conn = self.get_connection()
  80. try:
  81. with conn.cursor() as cur:
  82. embedding = self.text_to_vector(text)
  83. metadata = metadata or {}
  84. insert_sql = f"""
  85. INSERT INTO {table_name} (text_content, embedding, metadata)
  86. VALUES (%s, %s, %s)
  87. RETURNING id;
  88. """
  89. cur.execute(insert_sql, (text, embedding, json.dumps(metadata)))
  90. inserted_id = cur.fetchone()[0]
  91. conn.commit()
  92. print(f"Text inserted with ID: {inserted_id}")
  93. return inserted_id
  94. except Exception as e:
  95. print(f"Error inserting text: {e}")
  96. conn.rollback()
  97. return None
  98. finally:
  99. conn.close()
  100. def add_batch_documents(self , param: Dict[str, Any] , documents: List[Dict[str, Any]]):
  101. """
  102. 批量插入文本
  103. texts: [{'text': '...', 'metadata': {...}}, ...]
  104. """
  105. table_name = param.get('table_name')
  106. conn = self.get_connection()
  107. try:
  108. with conn.cursor() as cur:
  109. # 准备数据
  110. data_to_insert = []
  111. for item in documents:
  112. text = item['content']
  113. metadata = item.get('metadata', {})
  114. embedding = self.text_to_vector(text)
  115. data_to_insert.append((text, embedding, json.dumps(metadata)))
  116. # 批量插入
  117. insert_sql = f"""
  118. INSERT INTO {table_name} (text_content, embedding, metadata)
  119. VALUES (%s, %s, %s)
  120. """
  121. cur.executemany(insert_sql, data_to_insert)
  122. conn.commit()
  123. logger.info(f"Batch inserted {len(data_to_insert)} records")
  124. except Exception as e:
  125. logger.error(f"Error batch inserting: {e}")
  126. conn.rollback()
  127. finally:
  128. conn.close()
  129. def similarity_search(self, param: Dict[str, Any], query_text: str , min_score=0.5 ,
  130. top_k=5, filters: Dict[str, Any] = None):
  131. """
  132. 搜索相似文本
  133. search_similar 使用距离度量(越小越相似)
  134. """
  135. table_name = param.get('table_name')
  136. conn = self.get_connection()
  137. try:
  138. with conn.cursor(cursor_factory=RealDictCursor) as cur:
  139. query_embedding = self.text_to_vector(query_text)
  140. search_sql = f"""
  141. SELECT id, text_content, metadata,
  142. embedding <=> %s::vector AS distance
  143. FROM {table_name}
  144. ORDER BY embedding <=> %s::vector
  145. LIMIT %s;
  146. """
  147. cur.execute(search_sql, (query_embedding, query_embedding, top_k))
  148. results = cur.fetchall()
  149. return results
  150. except Exception as e:
  151. logger.error(f"Error searching: {e}")
  152. return []
  153. finally:
  154. conn.close()
  155. def retriever(self, param: Dict[str, Any], query_text: str , min_score=0.1 ,
  156. top_k=10, filters: Dict[str, Any] = None):
  157. """
  158. 使用余弦相似度搜索相似文本
  159. """
  160. table_name = param.get('table_name')
  161. conn = self.get_connection()
  162. try:
  163. with conn.cursor(cursor_factory=RealDictCursor) as cur:
  164. query_embedding = self.text_to_vector(query_text)
  165. search_sql = f"""
  166. SELECT id, text_content, metadata,
  167. 1 - (embedding <=> %s::vector) AS cosine_similarity
  168. FROM {table_name}
  169. WHERE 1 - (embedding <=> %s::vector) > %s
  170. ORDER BY 1 - (embedding <=> %s::vector) DESC
  171. LIMIT %s;
  172. """
  173. cur.execute(search_sql, (query_embedding, query_embedding, min_score, query_embedding, top_k))
  174. results = cur.fetchall()
  175. # 打印结果
  176. self.result_logger_info(query_text , results)
  177. return results
  178. except Exception as e:
  179. logger.error(f"Error searching with cosine similarity: {e}")
  180. return []
  181. finally:
  182. conn.close()
  183. def result_logger_info(self , query, result_docs_cos):
  184. """
  185. 记录搜索结果
  186. """
  187. logger.info(f"\n {'=' * 50}")
  188. # 使用余弦相似度搜索
  189. logger.info(f"\nSimilar documents with cosine similarity,query:{query},result_count: {len(result_docs_cos)}:")
  190. for doc in result_docs_cos:
  191. logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['cosine_similarity']:.3f}")
  192. def db_test(self , query_text: str):
  193. """
  194. 测试数据库连接和操作
  195. """
  196. table_name = 'test_documents'
  197. # 创建表
  198. self.create_table(table_name, vector_dim=768)
  199. # 插入单个文本
  200. sample_text = "这是一个关于人工智能的文档。"
  201. #self.insert_text(table_name, sample_text, {'category': 'AI', 'source': 'example'})
  202. # 批量插入文本
  203. sample_texts = [
  204. {
  205. 'text': '机器学习是人工智能的一个重要分支。',
  206. 'metadata': {'category': 'ML', 'author': 'John'}
  207. },
  208. {
  209. 'text': '深度学习在图像识别领域取得了显著成果。',
  210. 'metadata': {'category': 'Deep Learning', 'author': 'Jane'}
  211. },
  212. {
  213. 'text': '自然语言处理技术在聊天机器人中得到广泛应用。',
  214. 'metadata': {'category': 'NLP', 'author': 'Bob'}
  215. }
  216. ]
  217. #self.batch_insert_texts(table_name, sample_texts)
  218. logger.info(f"\n {'=' * 50}")
  219. # 搜索相似文本
  220. #query = "人工智能相关的技术"
  221. query = query_text
  222. logger.info(f"\n query={query}")
  223. similar_docs = self.search_similar(table_name, query, top_k=3)
  224. logger.info(f"Similar documents found {len(similar_docs)}:")
  225. for doc in similar_docs:
  226. logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {1 - doc['distance']:.3f}")
  227. logger.info(f"\n {'=' * 50}")
  228. # 使用余弦相似度搜索
  229. similar_docs_cos = self.search_by_cosine_similarity(table_name, query, top_k=3)
  230. logger.info(f"\nSimilar documents with cosine similarity {len(similar_docs_cos)}:")
  231. for doc in similar_docs_cos:
  232. logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['cosine_similarity']:.3f}")