# pg_vector.py — pgvector-backed implementation of BaseVectorDB.
  1. import psycopg2
  2. from psycopg2.extras import RealDictCursor
  3. import numpy as np
  4. #from sentence_transformers import SentenceTransformer
  5. import json
  6. from typing import List, Dict, Any
  7. from foundation.infrastructure.config.config import config_handler
  8. from foundation.observability.logger.loggering import review_logger as logger
  9. from foundation.database.base.vector.base_vector import BaseVectorDB
  10. class PGVectorDB(BaseVectorDB):
  11. def __init__(self):
  12. """
  13. 初始化 pgvector 连接
  14. """
  15. self.connection_params = {
  16. 'host': config_handler.get('pgvector', 'PGVECTOR_HOST', 'localhost'),
  17. 'port': int(config_handler.get('pgvector', 'PGVECTOR_PORT', '5432')),
  18. 'database': config_handler.get('pgvector', 'PGVECTOR_DB', 'postgres'),
  19. 'user': config_handler.get('pgvector', 'PGVECTOR_USER', 'postgres'),
  20. 'password': config_handler.get('pgvector', 'PGVECTOR_PASSWORD', 'postgres')
  21. }
  22. def get_connection(self):
  23. """获取数据库连接"""
  24. #logger.info(f"Connecting to PostgreSQL...{self.connection_params}")
  25. conn = psycopg2.connect(**self.connection_params)
  26. # 启用 pgvector 扩展
  27. with conn.cursor() as cur:
  28. cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
  29. conn.commit()
  30. return conn
  31. def create_table(self, table_name: str, vector_dim: int = 384):
  32. """
  33. 创建向量表
  34. """
  35. conn = self.get_connection()
  36. try:
  37. with conn.cursor() as cur:
  38. # 创建表
  39. create_table_sql = f"""
  40. CREATE TABLE IF NOT EXISTS {table_name} (
  41. id SERIAL PRIMARY KEY,
  42. text_content TEXT,
  43. embedding vector({vector_dim}),
  44. metadata JSONB DEFAULT '{{}}'::jsonb,
  45. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  46. );
  47. -- 创建向量相似度索引
  48. CREATE INDEX IF NOT EXISTS idx_{table_name}_embedding
  49. ON {table_name} USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);
  50. """
  51. cur.execute(create_table_sql)
  52. conn.commit()
  53. print(f"Table {table_name} created successfully!")
  54. except Exception as e:
  55. logger.error(f"Error creating table: {e}")
  56. conn.rollback()
  57. finally:
  58. conn.close()
  59. def document_standard(self, documents: List[Dict[str, Any]]):
  60. """
  61. 对文档进行结果标准处理
  62. """
  63. result = []
  64. for doc in documents:
  65. tmp = {}
  66. tmp['content'] = doc.page_content
  67. tmp['metadata'] = doc.metadata if doc.metadata else {}
  68. result.append(tmp)
  69. return result
  70. def add_document(self , param: Dict[str, Any] , document: Dict[str, Any]):
  71. """
  72. 插入单个文本及其向量
  73. """
  74. table_name = param.get('table_name')
  75. text = document.get('content')
  76. metadata = document.get('metadata')
  77. conn = self.get_connection()
  78. try:
  79. with conn.cursor() as cur:
  80. embedding = self.text_to_vector(text)
  81. metadata = metadata or {}
  82. insert_sql = f"""
  83. INSERT INTO {table_name} (text_content, embedding, metadata)
  84. VALUES (%s, %s, %s)
  85. RETURNING id;
  86. """
  87. cur.execute(insert_sql, (text, embedding, json.dumps(metadata)))
  88. inserted_id = cur.fetchone()[0]
  89. conn.commit()
  90. print(f"Text inserted with ID: {inserted_id}")
  91. return inserted_id
  92. except Exception as e:
  93. print(f"Error inserting text: {e}")
  94. conn.rollback()
  95. return None
  96. finally:
  97. conn.close()
  98. def add_batch_documents(self , param: Dict[str, Any] , documents: List[Dict[str, Any]]):
  99. """
  100. 批量插入文本
  101. texts: [{'text': '...', 'metadata': {...}}, ...]
  102. """
  103. table_name = param.get('table_name')
  104. conn = self.get_connection()
  105. try:
  106. with conn.cursor() as cur:
  107. # 准备数据
  108. data_to_insert = []
  109. for item in documents:
  110. text = item['content']
  111. metadata = item.get('metadata', {})
  112. embedding = self.text_to_vector(text)
  113. data_to_insert.append((text, embedding, json.dumps(metadata)))
  114. # 批量插入
  115. insert_sql = f"""
  116. INSERT INTO {table_name} (text_content, embedding, metadata)
  117. VALUES (%s, %s, %s)
  118. """
  119. cur.executemany(insert_sql, data_to_insert)
  120. conn.commit()
  121. logger.info(f"Batch inserted {len(data_to_insert)} records")
  122. except Exception as e:
  123. logger.error(f"Error batch inserting: {e}")
  124. conn.rollback()
  125. finally:
  126. conn.close()
  127. def similarity_search(self, param: Dict[str, Any], query_text: str , min_score=0.5 ,
  128. top_k=5, filters: Dict[str, Any] = None):
  129. """
  130. 搜索相似文本
  131. search_similar 使用距离度量(越小越相似)
  132. """
  133. table_name = param.get('table_name')
  134. conn = self.get_connection()
  135. try:
  136. with conn.cursor(cursor_factory=RealDictCursor) as cur:
  137. query_embedding = self.text_to_vector(query_text)
  138. search_sql = f"""
  139. SELECT id, text_content, metadata,
  140. embedding <=> %s::vector AS distance
  141. FROM {table_name}
  142. ORDER BY embedding <=> %s::vector
  143. LIMIT %s;
  144. """
  145. cur.execute(search_sql, (query_embedding, query_embedding, top_k))
  146. results = cur.fetchall()
  147. return results
  148. except Exception as e:
  149. logger.error(f"Error searching: {e}")
  150. return []
  151. finally:
  152. conn.close()
  153. def retriever(self, param: Dict[str, Any], query_text: str , min_score=0.1 ,
  154. top_k=10, filters: Dict[str, Any] = None):
  155. """
  156. 使用余弦相似度搜索相似文本
  157. """
  158. table_name = param.get('table_name')
  159. conn = self.get_connection()
  160. try:
  161. with conn.cursor(cursor_factory=RealDictCursor) as cur:
  162. query_embedding = self.text_to_vector(query_text)
  163. search_sql = f"""
  164. SELECT id, text_content, metadata,
  165. 1 - (embedding <=> %s::vector) AS cosine_similarity
  166. FROM {table_name}
  167. WHERE 1 - (embedding <=> %s::vector) > %s
  168. ORDER BY 1 - (embedding <=> %s::vector) DESC
  169. LIMIT %s;
  170. """
  171. cur.execute(search_sql, (query_embedding, query_embedding, min_score, query_embedding, top_k))
  172. results = cur.fetchall()
  173. # 打印结果
  174. self.result_logger_info(query_text , results)
  175. return results
  176. except Exception as e:
  177. logger.error(f"Error searching with cosine similarity: {e}")
  178. return []
  179. finally:
  180. conn.close()
  181. def result_logger_info(self , query, result_docs_cos):
  182. """
  183. 记录搜索结果
  184. """
  185. logger.info(f"\n {'=' * 50}")
  186. # 使用余弦相似度搜索
  187. logger.info(f"\nSimilar documents with cosine similarity,query:{query},result_count: {len(result_docs_cos)}:")
  188. for doc in result_docs_cos:
  189. logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['cosine_similarity']:.3f}")
  190. def db_test(self , query_text: str):
  191. """
  192. 测试数据库连接和操作
  193. """
  194. table_name = 'test_documents'
  195. # 创建表
  196. self.create_table(table_name, vector_dim=768)
  197. # 插入单个文本
  198. sample_text = "这是一个关于人工智能的文档。"
  199. #self.insert_text(table_name, sample_text, {'category': 'AI', 'source': 'example'})
  200. # 批量插入文本
  201. sample_texts = [
  202. {
  203. 'text': '机器学习是人工智能的一个重要分支。',
  204. 'metadata': {'category': 'ML', 'author': 'John'}
  205. },
  206. {
  207. 'text': '深度学习在图像识别领域取得了显著成果。',
  208. 'metadata': {'category': 'Deep Learning', 'author': 'Jane'}
  209. },
  210. {
  211. 'text': '自然语言处理技术在聊天机器人中得到广泛应用。',
  212. 'metadata': {'category': 'NLP', 'author': 'Bob'}
  213. }
  214. ]
  215. #self.batch_insert_texts(table_name, sample_texts)
  216. logger.info(f"\n {'=' * 50}")
  217. # 搜索相似文本
  218. #query = "人工智能相关的技术"
  219. query = query_text
  220. logger.info(f"\n query={query}")
  221. similar_docs = self.search_similar(table_name, query, top_k=3)
  222. logger.info(f"Similar documents found {len(similar_docs)}:")
  223. for doc in similar_docs:
  224. logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {1 - doc['distance']:.3f}")
  225. logger.info(f"\n {'=' * 50}")
  226. # 使用余弦相似度搜索
  227. similar_docs_cos = self.search_by_cosine_similarity(table_name, query, top_k=3)
  228. logger.info(f"\nSimilar documents with cosine similarity {len(similar_docs_cos)}:")
  229. for doc in similar_docs_cos:
  230. logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['cosine_similarity']:.3f}")