# Milvus vector store manager: dense (ANN) and hybrid (dense + BM25) search.
import time
import json
from typing import List, Dict, Any, Optional

import numpy as np

from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility, Function
from pymilvus.client.types import FunctionType
from pymilvus import AnnSearchRequest, RRFRanker, WeightedRanker
# from sentence_transformers import SentenceTransformer

# LangChain Milvus hybrid-search integration
from langchain_milvus import Milvus, BM25BuiltInFunction
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings

from foundation.infrastructure.config.config import config_handler
from foundation.database.base.vector.base_vector import BaseVectorDB
- # 延迟导入logger和model_handler以避免循环依赖
- logger = None
- model_handler = None
- def _get_logger():
- """延迟导入logger以避免循环依赖"""
- global logger
- if logger is None:
- try:
- from foundation.observability.logger.loggering import server_logger
- logger = server_logger
- except ImportError:
- # 如果导入失败,创建一个简单的logger替代品
- import logging
- logger = logging.getLogger(__name__)
- return logger
- def _get_model_handler():
- """延迟导入model_handler以避免循环依赖"""
- global model_handler
- if model_handler is None:
- try:
- from foundation.ai.models.model_handler import model_handler as mh
- model_handler = mh
- except ImportError:
- # 如果导入失败,返回None
- model_handler = None
- return model_handler
class MilvusVectorManager(BaseVectorDB):
    """Milvus-backed vector store.

    Supports a classic dense-vector collection (manual schema + IVF_FLAT/COSINE
    index) plus LangChain-based hybrid collections combining dense embeddings
    with a BM25 sparse function.
    """

    def __init__(self):
        """Read connection settings from config, build the embedding model, connect.

        Raises:
            ImportError: if the model handler (and thus the embedding model)
                cannot be imported.
        """
        self.host = config_handler.get('milvus', 'MILVUS_HOST', 'localhost')
        self.port = int(config_handler.get('milvus', 'MILVUS_PORT', '19530'))
        # NOTE(review): the DB name used to be hard-coded to "lq_db" while
        # MILVUS_DB was read but ignored. The config value is honored now; the
        # default keeps the old hard-coded name so behavior is unchanged when
        # MILVUS_DB is absent.
        self.milvus_db = config_handler.get('milvus', 'MILVUS_DB', 'lq_db')
        self.user = config_handler.get('milvus', 'MILVUS_USER')
        self.password = config_handler.get('milvus', 'MILVUS_PASSWORD')

        # Text embedding model, resolved lazily to avoid circular imports.
        mh = _get_model_handler()
        if mh:
            self.emdmodel = mh._get_lq_qwen3_8b_emd()
        else:
            raise ImportError("无法导入model_handler,无法初始化嵌入模型")

        self.connect()

    def text_to_vector(self, text: str) -> List[float]:
        """Embed *text* with the configured model and return a plain float list.

        Overrides the base-class method to use the embedding model directly.
        Raises on embedding failure.
        """
        try:
            embedding = self.emdmodel.embed_query(text)
            # Normalize numpy arrays (and any sequence) to list[float].
            return embedding.tolist() if hasattr(embedding, 'tolist') else list(embedding)
        except Exception as e:
            _get_logger().error(f"Error converting text to vector: {e}")
            raise

    def connect(self):
        """Open the process-wide "default" Milvus connection.

        Credentials are forwarded only when configured (the old code read the
        password from config but never passed it). Raises on failure.
        """
        try:
            conn_kwargs = {
                "alias": "default",
                "host": self.host,
                "port": self.port,
                "db_name": self.milvus_db,
            }
            if self.user:
                conn_kwargs["user"] = self.user
            if self.password:
                conn_kwargs["password"] = self.password
            connections.connect(**conn_kwargs)
            _get_logger().info(f"Connected to Milvus at {self.host}:{self.port}")
        except Exception as e:
            _get_logger().error(f"Failed to connect to Milvus: {e}")
            raise

    def create_collection(self, collection_name: str, dimension: int = 768,
                          description: str = "Vector collection for text embeddings"):
        """(Re)create a dense-vector collection.

        WARNING: destructive — an existing collection with the same name is
        dropped first, discarding all of its data.

        Args:
            collection_name: target collection name.
            dimension: dense-vector dimension (must match the embedding model).
            description: free-text collection description.

        Raises on any Milvus error.
        """
        try:
            if utility.has_collection(collection_name):
                _get_logger().info(f"Collection {collection_name} already exists")
                utility.drop_collection(collection_name)
                _get_logger().info(f"Collection '{collection_name}' dropped successfully")

            fields = [
                FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
                FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
                FieldSchema(name="text_content", dtype=DataType.VARCHAR, max_length=65535),
                FieldSchema(name="metadata", dtype=DataType.JSON),
                FieldSchema(name="created_at", dtype=DataType.INT64),
            ]
            schema = CollectionSchema(fields=fields, description=description)
            collection = Collection(name=collection_name, schema=schema)

            # COSINE keeps scores in a fixed range regardless of vector norms.
            index_params = {
                "index_type": "IVF_FLAT",
                "metric_type": "COSINE",
                "params": {"nlist": 100},
            }
            collection.create_index(field_name="vector", index_params=index_params)
            _get_logger().info(f"Collection {collection_name} created successfully!")
        except Exception as e:
            _get_logger().error(f"Error creating collection: {e}")
            raise

    def add_document(self, param: Dict[str, Any], document: Dict[str, Any]):
        """Embed and insert a single document.

        Args:
            param: must contain 'collection_name'.
            document: {'content': str, 'metadata': dict (optional)}.

        Returns:
            The inserted primary key, or None on failure (errors are logged,
            not raised — callers rely on the best-effort contract).
        """
        try:
            collection_name = param.get('collection_name')
            text = document.get('content')
            metadata = document.get('metadata')
            collection = Collection(collection_name)

            embedding = self.text_to_vector(text)

            # Column-ordered payload matching the schema (id is auto-generated).
            data = [
                [embedding],            # vector
                [text],                 # text_content
                [metadata or {}],       # metadata
                [int(time.time())],     # created_at
            ]
            _get_logger().info(f"Preparing to insert text_contents:{len(data[0])} ,{len(data[1])},{len(data[2])},{len(data[3])}")

            insert_result = collection.insert(data)
            collection.flush()  # make the row durable before returning its key

            _get_logger().info(f"Text inserted with ID: {insert_result.primary_keys[0]}")
            return insert_result.primary_keys[0]
        except Exception as e:
            _get_logger().error(f"Error inserting text: {e}")
            return None

    def add_batch_documents(self, param: Dict[str, Any], documents: List[Dict[str, Any]]):
        """Embed and insert many documents in one call.

        Args:
            param: must contain 'collection_name'.
            documents: [{'content': '...', 'metadata': {...}}, ...].

        Returns:
            List of inserted primary keys ([] for empty input), or None on
            failure (logged, not raised).
        """
        try:
            if not documents:
                # Nothing to insert; avoid sending empty columns to Milvus.
                return []

            collection_name = param.get('collection_name')
            collection = Collection(collection_name)

            text_contents = []
            embeddings = []
            metadatas = []
            timestamps = []
            for item in documents:
                text = item['content']
                text_contents.append(text)
                embeddings.append(self.text_to_vector(text))
                metadatas.append(item.get('metadata', {}))
                timestamps.append(int(time.time()))

            data = [embeddings, text_contents, metadatas, timestamps]
            insert_result = collection.insert(data)
            collection.flush()  # make the batch durable

            _get_logger().info(f"Batch inserted {len(text_contents)} records, IDs: {insert_result.primary_keys}")
            return insert_result.primary_keys
        except Exception as e:
            _get_logger().error(f"Error batch inserting: {e}")
            return None

    def _vector_search(self, collection_name: str, query_text: str,
                       top_k: int, filters: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Shared dense ANN search; returns formatted hit dicts. May raise."""
        collection = Collection(collection_name)
        collection.load()  # no-op if the collection is already in memory

        query_embedding = self.text_to_vector(query_text)
        search_params = {
            "metric_type": "COSINE",
            "params": {"nprobe": 10},
        }
        filter_expr = self._create_filter(filters)

        results = collection.search(
            data=[query_embedding],
            anns_field="vector",
            param=search_params,
            limit=top_k,
            expr=filter_expr,
            output_fields=["text_content", "metadata"],
        )

        formatted_results = []
        for hits in results:
            for hit in hits:
                formatted_results.append({
                    'id': hit.id,
                    'text_content': hit.entity.get('text_content'),
                    'metadata': hit.entity.get('metadata'),
                    'distance': hit.distance,
                    # With metric_type COSINE, Milvus reports cosine similarity
                    # (higher = more similar), so the score IS the similarity;
                    # the previous `1 - distance` inverted the ranking.
                    'similarity': hit.distance,
                })
        return formatted_results

    def similarity_search(self, param: Dict[str, Any], query_text: str, min_score=0.5,
                          top_k=5, filters: Dict[str, Any] = None):
        """Dense similarity search with a minimum-similarity cutoff.

        Args:
            param: must contain 'collection_name'.
            query_text: text to embed and search with.
            min_score: minimum cosine similarity to keep a hit (previously
                accepted but silently ignored).
            top_k: maximum hits to retrieve.
            filters: optional metadata equality filters.

        Returns:
            List of hit dicts (possibly empty); [] on error (logged).
        """
        try:
            results = self._vector_search(param.get('collection_name'),
                                          query_text, top_k, filters)
            return [r for r in results if r['similarity'] >= min_score]
        except Exception as e:
            _get_logger().error(f"Error searching: {e}")
            return []

    def retriever(self, param: Dict[str, Any], query_text: str,
                  top_k: int = 5, filters: Dict[str, Any] = None):
        """Dense similarity search with optional metadata filters, no score cutoff.

        Same contract as similarity_search minus min_score; returns [] on error.
        """
        try:
            return self._vector_search(param.get('collection_name'),
                                       query_text, top_k, filters)
        except Exception as e:
            _get_logger().error(f"Error searching with filter: {e}")
            return []

    def _create_filter(self, filters: Dict[str, Any]) -> str:
        """Build a Milvus boolean filter expression from metadata equality pairs.

        String values (and keys) are escaped so embedded quotes/backslashes
        cannot break out of the expression. Returns "" when no filters given.
        """
        if not filters:
            return ""

        def _quote(s: str) -> str:
            # Escape backslashes first, then double quotes.
            return s.replace('\\', '\\\\').replace('"', '\\"')

        conditions = []
        for key, value in filters.items():
            if isinstance(value, str):
                conditions.append(f'metadata["{_quote(key)}"] == "{_quote(value)}"')
            elif isinstance(value, (int, float)):
                conditions.append(f'metadata["{_quote(key)}"] == {value}')
            else:
                # Non-scalar values are compared against their JSON encoding.
                conditions.append(f'metadata["{_quote(key)}"] == "{_quote(json.dumps(value))}"')
        return " and ".join(conditions)

    def create_hybrid_collection(self, collection_name: str, documents: List[Dict[str, Any]]):
        """Create a hybrid (dense + BM25 sparse) collection from documents.

        Args:
            collection_name: collection name (an existing one is dropped —
                drop_old=True).
            documents: [{'content': '...', 'metadata': {...}}, ...].

        Returns:
            The LangChain Milvus vector store, or None on failure (logged).
        """
        try:
            connection_args = {
                "uri": f"http://{self.host}:{self.port}",
                "user": self.user,
                "db_name": self.milvus_db,
            }
            if self.password:
                connection_args["password"] = self.password

            langchain_docs = [
                Document(page_content=doc.get('content', ''),
                         metadata=doc.get('metadata', {}))
                for doc in documents
            ]

            vectorstore = Milvus.from_documents(
                documents=langchain_docs,
                embedding=self.emdmodel,
                builtin_function=BM25BuiltInFunction(),
                vector_field=["dense", "sparse"],  # dense embedding + BM25 sparse
                connection_args=connection_args,
                collection_name=collection_name,
                consistency_level="Strong",
                drop_old=True,
            )
            _get_logger().info(f"Created hybrid collection: {collection_name} with {len(documents)} documents")
            return vectorstore
        except Exception as e:
            _get_logger().error(f"Error creating hybrid collection: {e}")
            _get_logger().info("Falling back to traditional vector search")
            return None

    def hybrid_search(self, param: Dict[str, Any], query_text: str,
                      top_k: int = 5, ranker_type: str = "weighted",
                      dense_weight: float = 0.7, sparse_weight: float = 0.3):
        """Hybrid (dense + sparse) search over a hybrid collection.

        Args:
            param: must contain 'collection_name'.
            query_text: query text.
            top_k: number of results.
            ranker_type: "weighted" (uses dense_weight/sparse_weight) or "rrf".
            dense_weight: dense-vector weight for the weighted ranker.
            sparse_weight: sparse-vector weight for the weighted ranker.

        Returns:
            List of result dicts; falls back to dense similarity_search on error.
        """
        try:
            collection_name = param.get('collection_name')
            connection_args = {
                "uri": f"http://{self.host}:{self.port}",
                "user": self.user,
                "db_name": self.milvus_db,
            }
            if self.password:
                connection_args["password"] = self.password

            vectorstore = Milvus(
                embedding_function=self.emdmodel,
                collection_name=collection_name,
                connection_args=connection_args,
                consistency_level="Strong",
                builtin_function=BM25BuiltInFunction(),
                vector_field=["dense", "sparse"],
            )

            if ranker_type == "weighted":
                ranker_params = {"weights": [dense_weight, sparse_weight]}
                results = vectorstore.similarity_search(
                    query=query_text,
                    k=top_k,
                    ranker_type="weighted",
                    ranker_params=ranker_params,
                )
            else:  # rrf
                results = vectorstore.similarity_search(
                    query=query_text,
                    k=top_k,
                    ranker_type="rrf",
                    ranker_params={"k": 60},
                )

            # Normalize to the shape returned by the other search methods.
            # LangChain does not expose per-hit scores here, so distance and
            # similarity are placeholders.
            formatted_results = [
                {
                    'id': doc.metadata.get('pk', 0),
                    'text_content': doc.page_content,
                    'metadata': doc.metadata,
                    'distance': 0.0,
                    'similarity': 1.0,
                }
                for doc in results
            ]
            _get_logger().info(f"Hybrid search returned {len(formatted_results)} results")
            return formatted_results
        except Exception as e:
            _get_logger().error(f"Error in hybrid search: {e}")
            _get_logger().info("Falling back to traditional vector search")
            # min_score=0.0 keeps the fallback unfiltered, as it always was.
            return self.similarity_search(param, query_text, min_score=0.0, top_k=top_k)
-
|