milvus_connection.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
"""
Milvus vector database connection management.
"""
  4. import os
  5. import logging
  6. from typing import Optional
  7. # 导入配置
  8. from app.core.config import config_handler
  9. from .embedding_connection import get_embedding_model
  10. from langchain_milvus import Milvus, BM25BuiltInFunction
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Process-wide MilvusManager singleton slot; filled on first use by
# get_milvus_manager() and reset to None by close_milvus().
_milvus_manager: Optional["MilvusManager"] = None
  13. def get_milvus_vectorstore(collection_name: str, consistency_level: str = "Strong"):
  14. """
  15. 获取 Milvus Vectorstore 实例(用于混合搜索)
  16. Args:
  17. collection_name: 集合名称
  18. consistency_level: 一致性级别,默认为 "Strong"
  19. Returns:
  20. Milvus: LangChain 的 Milvus Vectorstore 实例
  21. """
  22. try:
  23. # 直接调用embedding_connection的embedding
  24. embedding_function = get_embedding_model()
  25. manager = get_milvus_manager()
  26. connection_args = {
  27. "uri": f"http://{manager.host}:{manager.port}",
  28. "user": manager.user,
  29. "db_name": manager.db_name
  30. }
  31. if manager.password:
  32. connection_args["password"] = manager.password
  33. # 动态检测向量字段名称,兼容旧集合(vector)和新集合(dense)
  34. vector_field_name = "dense"
  35. try:
  36. desc = manager.client.describe_collection(collection_name)
  37. fields = desc.get("fields", []) if isinstance(desc, dict) else []
  38. float_vector_fields = []
  39. for f in fields:
  40. f_name = f.get("name")
  41. f_type = f.get("type")
  42. if not f_name:
  43. continue
  44. # DataType.FLOAT_VECTOR 在 pymilvus 中通常是 101,字符串形式可能为 "FloatVector"
  45. if f_type == 101 or str(f_type).upper() in ("FLOAT_VECTOR", "FLOATVECTOR"):
  46. float_vector_fields.append(f_name)
  47. # 优先 dense,其次 vector,再次第一个向量字段
  48. if "dense" in float_vector_fields:
  49. vector_field_name = "dense"
  50. elif "vector" in float_vector_fields:
  51. vector_field_name = "vector"
  52. elif float_vector_fields:
  53. vector_field_name = float_vector_fields[0]
  54. except Exception as e:
  55. logger.warning(f"自动检测向量字段失败,使用默认 'dense': {e}")
  56. vectorstore = Milvus(
  57. embedding_function=embedding_function,
  58. collection_name=collection_name,
  59. connection_args=connection_args,
  60. consistency_level=consistency_level,
  61. builtin_function=BM25BuiltInFunction(),
  62. vector_field=vector_field_name
  63. )
  64. return vectorstore
  65. except Exception as e:
  66. logger.error(f"获取 Milvus Vectorstore 失败: {e}")
  67. raise
  68. class MilvusManager:
  69. """Milvus管理器"""
  70. def __init__(self):
  71. self.host: str = config_handler.get("admin_app", "MILVUS_HOST", "localhost")
  72. self.port: int = config_handler.get_int("admin_app", "MILVUS_PORT", 19530)
  73. self.db_name: str = config_handler.get("admin_app", "MILVUS_DB", "default")
  74. self.user: Optional[str] = config_handler.get("admin_app", "MILVUS_USER", "")
  75. self.password: Optional[str] = config_handler.get("admin_app", "MILVUS_PASSWORD", "")
  76. self.uri = f"http://{self.host}:{self.port}"
  77. logger.info(f"初始化 MilvusClient: uri={self.uri}, db={self.db_name}")
  78. # 延迟初始化 client
  79. self._client = None
  80. @property
  81. def client(self):
  82. """获取 Milvus 客户端(延迟初始化)"""
  83. if self._client is None:
  84. try:
  85. from pymilvus import MilvusClient
  86. self._client = MilvusClient(
  87. uri=self.uri,
  88. user=self.user or "",
  89. password=self.password or "",
  90. db_name=self.db_name,
  91. )
  92. logger.info("Milvus客户端初始化成功")
  93. except Exception as e:
  94. logger.error(f"Milvus客户端初始化失败: {e}")
  95. raise
  96. return self._client
  97. def close(self) -> None:
  98. """关闭 Milvus 连接"""
  99. if self._client:
  100. try:
  101. self._client.close()
  102. logger.info("Milvus连接已关闭")
  103. except Exception as e:
  104. logger.error(f"关闭Milvus连接失败: {e}")
  105. finally:
  106. self._client = None
  107. def get_milvus_manager() -> MilvusManager:
  108. """获取 Milvus 管理器单例"""
  109. global _milvus_manager
  110. if _milvus_manager is None:
  111. _milvus_manager = MilvusManager()
  112. return _milvus_manager
  113. def get_milvus_connection():
  114. """获取Milvus连接(兼容旧接口)"""
  115. try:
  116. return get_milvus_manager().client
  117. except Exception as e:
  118. logger.warning(f"Milvus连接失败: {e}")
  119. return None
  120. async def init_milvus():
  121. """初始化Milvus连接"""
  122. try:
  123. get_milvus_connection()
  124. logger.info("Milvus初始化成功")
  125. except Exception as e:
  126. logger.warning(f"Milvus初始化失败: {e}")
  127. async def close_milvus():
  128. """关闭Milvus连接"""
  129. global _milvus_manager
  130. if _milvus_manager:
  131. try:
  132. _milvus_manager.close()
  133. logger.info("Milvus连接已关闭")
  134. except Exception as e:
  135. logger.error(f"关闭Milvus连接失败: {e}")
  136. finally:
  137. _milvus_manager = None