# pg_vector.py (9.9 KB)
  1. 
  2. import psycopg2
  3. from psycopg2.extras import RealDictCursor
  4. import numpy as np
  5. #from sentence_transformers import SentenceTransformer
  6. import json
  7. from typing import List, Dict, Any
  8. from foundation.infrastructure.config.config import config_handler
  9. from foundation.observability.logger.loggering import write_logger as logger
  10. from foundation.database.base.vector.base_vector import BaseVectorDB
  11. class PGVectorDB(BaseVectorDB):
  12. def __init__(self):
  13. """
  14. 初始化 pgvector 连接
  15. """
  16. self.connection_params = {
  17. 'host': config_handler.get('pgvector', 'PGVECTOR_HOST', 'localhost'),
  18. 'port': int(config_handler.get('pgvector', 'PGVECTOR_PORT', '5432')),
  19. 'database': config_handler.get('pgvector', 'PGVECTOR_DB', 'postgres'),
  20. 'user': config_handler.get('pgvector', 'PGVECTOR_USER', 'postgres'),
  21. 'password': config_handler.get('pgvector', 'PGVECTOR_PASSWORD', 'postgres')
  22. }
  23. def get_connection(self):
  24. """获取数据库连接"""
  25. #logger.info(f"Connecting to PostgreSQL...{self.connection_params}")
  26. conn = psycopg2.connect(**self.connection_params)
  27. # 启用 pgvector 扩展
  28. with conn.cursor() as cur:
  29. cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
  30. conn.commit()
  31. return conn
  32. def create_table(self, table_name: str, vector_dim: int = 384):
  33. """
  34. 创建向量表
  35. """
  36. conn = self.get_connection()
  37. try:
  38. with conn.cursor() as cur:
  39. # 创建表
  40. create_table_sql = f"""
  41. CREATE TABLE IF NOT EXISTS {table_name} (
  42. id SERIAL PRIMARY KEY,
  43. text_content TEXT,
  44. embedding vector({vector_dim}),
  45. metadata JSONB DEFAULT '{{}}'::jsonb,
  46. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  47. );
  48. -- 创建向量相似度索引
  49. CREATE INDEX IF NOT EXISTS idx_{table_name}_embedding
  50. ON {table_name} USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);
  51. """
  52. cur.execute(create_table_sql)
  53. conn.commit()
  54. print(f"Table {table_name} created successfully!")
  55. except Exception as e:
  56. logger.error(f"Error creating table: {e}")
  57. conn.rollback()
  58. finally:
  59. conn.close()
  60. def document_standard(self, documents: List[Dict[str, Any]]):
  61. """
  62. 对文档进行结果标准处理
  63. """
  64. result = []
  65. for doc in documents:
  66. tmp = {}
  67. tmp['content'] = doc.page_content
  68. tmp['metadata'] = doc.metadata if doc.metadata else {}
  69. result.append(tmp)
  70. return result
  71. def add_document(self , param: Dict[str, Any] , document: Dict[str, Any]):
  72. """
  73. 插入单个文本及其向量
  74. """
  75. table_name = param.get('table_name')
  76. text = document.get('content')
  77. metadata = document.get('metadata')
  78. conn = self.get_connection()
  79. try:
  80. with conn.cursor() as cur:
  81. embedding = self.text_to_vector(text)
  82. metadata = metadata or {}
  83. insert_sql = f"""
  84. INSERT INTO {table_name} (text_content, embedding, metadata)
  85. VALUES (%s, %s, %s)
  86. RETURNING id;
  87. """
  88. cur.execute(insert_sql, (text, embedding, json.dumps(metadata)))
  89. inserted_id = cur.fetchone()[0]
  90. conn.commit()
  91. print(f"Text inserted with ID: {inserted_id}")
  92. return inserted_id
  93. except Exception as e:
  94. print(f"Error inserting text: {e}")
  95. conn.rollback()
  96. return None
  97. finally:
  98. conn.close()
  99. def add_batch_documents(self , param: Dict[str, Any] , documents: List[Dict[str, Any]]):
  100. """
  101. 批量插入文本
  102. texts: [{'text': '...', 'metadata': {...}}, ...]
  103. """
  104. table_name = param.get('table_name')
  105. conn = self.get_connection()
  106. try:
  107. with conn.cursor() as cur:
  108. # 准备数据
  109. data_to_insert = []
  110. for item in documents:
  111. text = item['content']
  112. metadata = item.get('metadata', {})
  113. embedding = self.text_to_vector(text)
  114. data_to_insert.append((text, embedding, json.dumps(metadata)))
  115. # 批量插入
  116. insert_sql = f"""
  117. INSERT INTO {table_name} (text_content, embedding, metadata)
  118. VALUES (%s, %s, %s)
  119. """
  120. cur.executemany(insert_sql, data_to_insert)
  121. conn.commit()
  122. logger.info(f"Batch inserted {len(data_to_insert)} records")
  123. except Exception as e:
  124. logger.error(f"Error batch inserting: {e}")
  125. conn.rollback()
  126. finally:
  127. conn.close()
  128. def similarity_search(self, param: Dict[str, Any], query_text: str , min_score=0.5 ,
  129. top_k=5, filters: Dict[str, Any] = None):
  130. """
  131. 搜索相似文本
  132. search_similar 使用距离度量(越小越相似)
  133. """
  134. table_name = param.get('table_name')
  135. conn = self.get_connection()
  136. try:
  137. with conn.cursor(cursor_factory=RealDictCursor) as cur:
  138. query_embedding = self.text_to_vector(query_text)
  139. search_sql = f"""
  140. SELECT id, text_content, metadata,
  141. embedding <=> %s::vector AS distance
  142. FROM {table_name}
  143. ORDER BY embedding <=> %s::vector
  144. LIMIT %s;
  145. """
  146. cur.execute(search_sql, (query_embedding, query_embedding, top_k))
  147. results = cur.fetchall()
  148. return results
  149. except Exception as e:
  150. logger.error(f"Error searching: {e}")
  151. return []
  152. finally:
  153. conn.close()
  154. def retriever(self, param: Dict[str, Any], query_text: str , min_score=0.1 ,
  155. top_k=10, filters: Dict[str, Any] = None):
  156. """
  157. 使用余弦相似度搜索相似文本
  158. """
  159. table_name = param.get('table_name')
  160. conn = self.get_connection()
  161. try:
  162. with conn.cursor(cursor_factory=RealDictCursor) as cur:
  163. query_embedding = self.text_to_vector(query_text)
  164. search_sql = f"""
  165. SELECT id, text_content, metadata,
  166. 1 - (embedding <=> %s::vector) AS cosine_similarity
  167. FROM {table_name}
  168. WHERE 1 - (embedding <=> %s::vector) > %s
  169. ORDER BY 1 - (embedding <=> %s::vector) DESC
  170. LIMIT %s;
  171. """
  172. cur.execute(search_sql, (query_embedding, query_embedding, min_score, query_embedding, top_k))
  173. results = cur.fetchall()
  174. # 打印结果
  175. self.result_logger_info(query_text , results)
  176. return results
  177. except Exception as e:
  178. logger.error(f"Error searching with cosine similarity: {e}")
  179. return []
  180. finally:
  181. conn.close()
  182. def result_logger_info(self , query, result_docs_cos):
  183. """
  184. 记录搜索结果
  185. """
  186. logger.info(f"\n {'=' * 50}")
  187. # 使用余弦相似度搜索
  188. logger.info(f"\nSimilar documents with cosine similarity,query:{query},result_count: {len(result_docs_cos)}:")
  189. for doc in result_docs_cos:
  190. logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['cosine_similarity']:.3f}")
  191. def db_test(self , query_text: str):
  192. """
  193. 测试数据库连接和操作
  194. """
  195. table_name = 'test_documents'
  196. # 创建表
  197. self.create_table(table_name, vector_dim=768)
  198. # 插入单个文本
  199. sample_text = "这是一个关于人工智能的文档。"
  200. #self.insert_text(table_name, sample_text, {'category': 'AI', 'source': 'example'})
  201. # 批量插入文本
  202. sample_texts = [
  203. {
  204. 'text': '机器学习是人工智能的一个重要分支。',
  205. 'metadata': {'category': 'ML', 'author': 'John'}
  206. },
  207. {
  208. 'text': '深度学习在图像识别领域取得了显著成果。',
  209. 'metadata': {'category': 'Deep Learning', 'author': 'Jane'}
  210. },
  211. {
  212. 'text': '自然语言处理技术在聊天机器人中得到广泛应用。',
  213. 'metadata': {'category': 'NLP', 'author': 'Bob'}
  214. }
  215. ]
  216. #self.batch_insert_texts(table_name, sample_texts)
  217. logger.info(f"\n {'=' * 50}")
  218. # 搜索相似文本
  219. #query = "人工智能相关的技术"
  220. query = query_text
  221. logger.info(f"\n query={query}")
  222. similar_docs = self.search_similar(table_name, query, top_k=3)
  223. logger.info(f"Similar documents found {len(similar_docs)}:")
  224. for doc in similar_docs:
  225. logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {1 - doc['distance']:.3f}")
  226. logger.info(f"\n {'=' * 50}")
  227. # 使用余弦相似度搜索
  228. similar_docs_cos = self.search_by_cosine_similarity(table_name, query, top_k=3)
  229. logger.info(f"\nSimilar documents with cosine similarity {len(similar_docs_cos)}:")
  230. for doc in similar_docs_cos:
  231. logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['cosine_similarity']:.3f}")