import json
from typing import Any, Dict, List

import numpy as np
import psycopg2
from psycopg2.extras import RealDictCursor
# from sentence_transformers import SentenceTransformer

from foundation.infrastructure.config.config import config_handler
from foundation.observability.logger.loggering import review_logger as logger
from foundation.database.base.vector.base_vector import BaseVectorDB
class PGVectorDB(BaseVectorDB):
    """Vector store backed by PostgreSQL with the pgvector extension.

    Connection settings are read from the ``pgvector`` section of the
    application config.  Text is embedded via ``self.text_to_vector``, which
    is expected to be provided by ``BaseVectorDB`` or a subclass (it is not
    defined in this file -- TODO confirm).
    """

    def __init__(self):
        """Read pgvector connection parameters from the config handler."""
        self.connection_params = {
            'host': config_handler.get('pgvector', 'PGVECTOR_HOST', 'localhost'),
            'port': int(config_handler.get('pgvector', 'PGVECTOR_PORT', '5432')),
            'database': config_handler.get('pgvector', 'PGVECTOR_DB', 'postgres'),
            'user': config_handler.get('pgvector', 'PGVECTOR_USER', 'postgres'),
            'password': config_handler.get('pgvector', 'PGVECTOR_PASSWORD', 'postgres'),
        }

    @staticmethod
    def _safe_identifier(name: str) -> str:
        """Return *name* unchanged if it is a plain SQL identifier.

        Table names cannot be bound as SQL parameters, so this class
        interpolates them into query text with f-strings.  Restricting them
        to simple identifiers blocks SQL injection through
        ``param['table_name']``.

        Raises:
            ValueError: if *name* is missing or not a bare identifier.
        """
        if not isinstance(name, str) or not name.isidentifier():
            raise ValueError(f"Invalid table name: {name!r}")
        return name

    def get_connection(self):
        """Open a new connection and ensure the ``vector`` extension exists.

        ``CREATE EXTENSION IF NOT EXISTS`` is idempotent, so running it per
        connection is safe (if somewhat redundant).
        """
        conn = psycopg2.connect(**self.connection_params)
        with conn.cursor() as cur:
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
            conn.commit()
        return conn

    def create_table(self, table_name: str, vector_dim: int = 384):
        """Create the vector table and an ivfflat cosine index if missing.

        Args:
            table_name: Target table; must be a plain SQL identifier.
            vector_dim: Dimension of the stored embeddings.
        """
        table_name = self._safe_identifier(table_name)
        conn = self.get_connection()
        try:
            with conn.cursor() as cur:
                create_table_sql = f"""
                CREATE TABLE IF NOT EXISTS {table_name} (
                    id SERIAL PRIMARY KEY,
                    text_content TEXT,
                    embedding vector({vector_dim}),
                    metadata JSONB DEFAULT '{{}}'::jsonb,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                );

                -- Approximate-nearest-neighbour index for cosine distance.
                CREATE INDEX IF NOT EXISTS idx_{table_name}_embedding
                ON {table_name} USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);
                """
                cur.execute(create_table_sql)
                conn.commit()
                # Was print(); switched to logger for consistency with the
                # rest of this class.
                logger.info(f"Table {table_name} created successfully!")
        except Exception as e:
            logger.error(f"Error creating table: {e}")
            conn.rollback()
        finally:
            conn.close()

    def document_standard(self, documents: List[Dict[str, Any]]):
        """Normalize LangChain-style document objects to plain dicts.

        Each input item must expose ``page_content`` and ``metadata``
        attributes; the result is ``[{'content': ..., 'metadata': ...}, ...]``
        with ``metadata`` defaulting to ``{}``.
        """
        return [
            {'content': doc.page_content, 'metadata': doc.metadata or {}}
            for doc in documents
        ]

    def add_document(self, param: Dict[str, Any], document: Dict[str, Any]):
        """Insert a single document and its embedding.

        Args:
            param: Must contain ``table_name``.
            document: ``{'content': str, 'metadata': dict | None}``.

        Returns:
            The inserted row id, or ``None`` if the insert failed.
        """
        table_name = self._safe_identifier(param.get('table_name'))
        text = document.get('content')
        metadata = document.get('metadata') or {}
        conn = self.get_connection()
        try:
            with conn.cursor() as cur:
                embedding = self.text_to_vector(text)

                insert_sql = f"""
                INSERT INTO {table_name} (text_content, embedding, metadata)
                VALUES (%s, %s, %s)
                RETURNING id;
                """
                cur.execute(insert_sql, (text, embedding, json.dumps(metadata)))
                inserted_id = cur.fetchone()[0]
                conn.commit()
                # Was print(); switched to logger for consistency.
                logger.info(f"Text inserted with ID: {inserted_id}")
                return inserted_id
        except Exception as e:
            logger.error(f"Error inserting text: {e}")
            conn.rollback()
            return None
        finally:
            conn.close()

    def add_batch_documents(self, param: Dict[str, Any], documents: List[Dict[str, Any]]):
        """Batch-insert documents.

        Args:
            param: Must contain ``table_name``.
            documents: ``[{'content': str, 'metadata': dict}, ...]`` -- note
                the required key is ``'content'``, not ``'text'``.
        """
        table_name = self._safe_identifier(param.get('table_name'))
        conn = self.get_connection()
        try:
            with conn.cursor() as cur:
                # Embed every document up front, then insert in one batch.
                data_to_insert = []
                for item in documents:
                    text = item['content']
                    metadata = item.get('metadata', {})
                    embedding = self.text_to_vector(text)
                    data_to_insert.append((text, embedding, json.dumps(metadata)))

                insert_sql = f"""
                INSERT INTO {table_name} (text_content, embedding, metadata)
                VALUES (%s, %s, %s)
                """
                cur.executemany(insert_sql, data_to_insert)
                conn.commit()
                logger.info(f"Batch inserted {len(data_to_insert)} records")
        except Exception as e:
            logger.error(f"Error batch inserting: {e}")
            conn.rollback()
        finally:
            conn.close()

    def similarity_search(self, param: Dict[str, Any], query_text: str, min_score=0.5,
                          top_k=5, filters: Dict[str, Any] = None):
        """Return the ``top_k`` nearest rows by cosine *distance*.

        Smaller distance means more similar.  ``min_score`` and ``filters``
        are accepted for interface compatibility but are currently unused
        by this implementation.

        Returns:
            A list of RealDictRow mappings with keys ``id``, ``text_content``,
            ``metadata`` and ``distance`` (empty list on error).
        """
        table_name = self._safe_identifier(param.get('table_name'))
        conn = self.get_connection()
        try:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                query_embedding = self.text_to_vector(query_text)

                search_sql = f"""
                SELECT id, text_content, metadata,
                       embedding <=> %s::vector AS distance
                FROM {table_name}
                ORDER BY embedding <=> %s::vector
                LIMIT %s;
                """
                cur.execute(search_sql, (query_embedding, query_embedding, top_k))
                return cur.fetchall()
        except Exception as e:
            logger.error(f"Error searching: {e}")
            return []
        finally:
            conn.close()

    def retriever(self, param: Dict[str, Any], query_text: str, min_score=0.1,
                  top_k=10, filters: Dict[str, Any] = None):
        """Return up to ``top_k`` rows whose cosine *similarity* exceeds
        ``min_score`` (larger is more similar).

        ``filters`` is accepted for interface compatibility but currently
        unused.

        Returns:
            A list of RealDictRow mappings with keys ``id``, ``text_content``,
            ``metadata`` and ``cosine_similarity`` (empty list on error).
        """
        table_name = self._safe_identifier(param.get('table_name'))
        conn = self.get_connection()
        try:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                query_embedding = self.text_to_vector(query_text)

                search_sql = f"""
                SELECT id, text_content, metadata,
                       1 - (embedding <=> %s::vector) AS cosine_similarity
                FROM {table_name}
                WHERE 1 - (embedding <=> %s::vector) > %s
                ORDER BY 1 - (embedding <=> %s::vector) DESC
                LIMIT %s;
                """
                cur.execute(search_sql, (query_embedding, query_embedding,
                                         min_score, query_embedding, top_k))
                results = cur.fetchall()
                self.result_logger_info(query_text, results)
                return results
        except Exception as e:
            logger.error(f"Error searching with cosine similarity: {e}")
            return []
        finally:
            conn.close()

    def result_logger_info(self, query, result_docs_cos):
        """Log a summary of cosine-similarity search results."""
        logger.info(f"\n {'=' * 50}")
        logger.info(f"\nSimilar documents with cosine similarity,query:{query},result_count: {len(result_docs_cos)}:")
        for doc in result_docs_cos:
            logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['cosine_similarity']:.3f}")

    def db_test(self, query_text: str):
        """Smoke-test: create a table and exercise both search paths.

        FIX: the original called non-existent methods ``search_similar`` /
        ``search_by_cosine_similarity`` with positional table names; they are
        mapped to ``similarity_search`` / ``retriever`` with the expected
        ``param`` dict here.
        """
        table_name = 'test_documents'
        param = {'table_name': table_name}
        self.create_table(table_name, vector_dim=768)

        # Example single insert (disabled):
        # self.add_document(param, {'content': "这是一个关于人工智能的文档。",
        #                           'metadata': {'category': 'AI', 'source': 'example'}})

        # Example batch insert (disabled).  FIX: documents must use the
        # 'content' key required by add_batch_documents, not 'text'.
        sample_documents = [
            {'content': '机器学习是人工智能的一个重要分支。',
             'metadata': {'category': 'ML', 'author': 'John'}},
            {'content': '深度学习在图像识别领域取得了显著成果。',
             'metadata': {'category': 'Deep Learning', 'author': 'Jane'}},
            {'content': '自然语言处理技术在聊天机器人中得到广泛应用。',
             'metadata': {'category': 'NLP', 'author': 'Bob'}},
        ]
        # self.add_batch_documents(param, sample_documents)

        logger.info(f"\n {'=' * 50}")
        query = query_text
        logger.info(f"\n query={query}")
        # Distance-ordered search (smaller distance = more similar).
        similar_docs = self.similarity_search(param, query, top_k=3)
        logger.info(f"Similar documents found {len(similar_docs)}:")
        for doc in similar_docs:
            logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {1 - doc['distance']:.3f}")

        logger.info(f"\n {'=' * 50}")
        # Cosine-similarity search (larger similarity = more similar).
        similar_docs_cos = self.retriever(param, query, top_k=3)
        logger.info(f"\nSimilar documents with cosine similarity {len(similar_docs_cos)}:")
        for doc in similar_docs_cos:
            logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['cosine_similarity']:.3f}")