milvus.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. from __future__ import annotations
  2. from dataclasses import dataclass, field
  3. from typing import Any, Dict, List, Optional, Sequence
  4. from pymilvus import MilvusClient
  5. from langchain_core.documents import Document
  6. from foundation.infrastructure.config.config import config_handler
  @dataclass(frozen=True)
  class MilvusConfig:
      """Connection settings for a Milvus server.

      When no explicit values are passed, both fields are read from
      ``config_handler`` (section ``'milvus'``) at instantiation time:

      * ``uri``     -- ``http://{MILVUS_HOST}:{MILVUS_PORT}``, falling back to
        ``localhost`` and ``19530``.
      * ``db_name`` -- ``MILVUS_DB``, falling back to ``lq_db``.

      Explicit keyword arguments override the configured values, e.g.
      ``MilvusConfig(uri="http://10.0.0.5:19530", db_name="other_db")``.
      Instances are immutable (``frozen=True``).
      """

      # default_factory: the config lookup runs each time an instance is
      # created without an explicit ``uri``, not when the module is imported.
      uri: str = field(
          default_factory=lambda: (
              f"http://{config_handler.get('milvus', 'MILVUS_HOST', 'localhost')}:"
              # int() makes a non-numeric port fail with ValueError here.
              f"{int(config_handler.get('milvus', 'MILVUS_PORT', '19530'))}"
          )
      )
      # Also a factory, like ``uri``: a plain default would be evaluated only
      # once, at class-definition time, ignoring later configuration changes.
      db_name: str = field(
          default_factory=lambda: config_handler.get('milvus', 'MILVUS_DB', 'lq_db')
      )
  class MilvusManager:
      """Thin manager around ``pymilvus.MilvusClient`` (does not use langchain-milvus).

      - Construction: creates a client for ``cfg.uri`` and switches it to
        ``cfg.db_name`` with ``use_database``.
      - Queries: the collection is not fixed; every call passes its own
        ``collection_name``.
      - Provided operations:
        1) ``condition_query``: pure scalar-filter query (``MilvusClient.query``).
      """

      def __init__(self, cfg: MilvusConfig, *, text_field: str = "text"):
          """Create the client and select the configured database.

          Args:
              cfg: Connection settings providing ``uri`` and ``db_name``.
              text_field: Field returned by ``condition_query`` when the caller
                  gives no ``output_fields``. Defaults to ``"text"``; override
                  it to match your collection schema.
          """
          self.cfg = cfg
          self.client = MilvusClient(uri=self.cfg.uri)
          self.client.use_database(self.cfg.db_name)
          # Conventional text field name (adjust to your schema).
          self.text_field = text_field

      def list_collections(self) -> List[str]:
          """Return the collection names reported by the client for the current database."""
          return self.client.list_collections()

      def condition_query(
          self,
          *,
          collection_name: str,
          filter: str,
          output_fields: Optional[Sequence[str]] = None,
          limit: Optional[int] = None,
      ) -> List[Dict[str, Any]]:
          """Run a scalar-filter query (no vector search) against one collection.

          Args:
              collection_name: Collection to query; must exist in the current
                  database.
              filter: Milvus boolean filter expression, for example
                  ``parent_id == 'xxx'`` or
                  ``tenant == 't1' and source == 'pdf'``.
                  (The name shadows the ``filter`` builtin; it is kept because
                  callers pass it by keyword.)
              output_fields: Fields to return, for example ``["text"]`` or
                  ``["text", "parent_id", "chunk_id"]``. A single field name
                  given as a plain string is accepted as well. Defaults to
                  ``[self.text_field]``.
              limit: Forwarded unchanged to ``MilvusClient.query``.

          Returns:
              The rows returned by ``MilvusClient.query``.

          Raises:
              ValueError: If ``collection_name`` is empty or ``None``.
              RuntimeError: If the collection does not exist in the current
                  database; the message lists the collections that do.
          """
          if not collection_name:
              raise ValueError("collection_name 不能为空")

          if output_fields is None:
              fields = [self.text_field]
          elif isinstance(output_fields, str):
              # A str is itself a Sequence[str]: list("text") would silently
              # turn into ["t", "e", "x", "t"], so wrap a lone name instead.
              fields = [output_fields]
          else:
              fields = list(output_fields)

          # Check up front so a missing collection produces a readable error
          # instead of a less obvious MilvusException from the query itself.
          if not self.client.has_collection(collection_name):
              existing = self.client.list_collections()
              raise RuntimeError(
                  f"collection not found: {collection_name}\n"
                  f"current db_name={self.cfg.db_name}, uri={self.cfg.uri}\n"
                  f"collections in current db: {existing}"
              )

          return self.client.query(
              collection_name=collection_name,
              filter=filter,
              output_fields=fields,
              limit=limit,
          )
  if __name__ == "__main__":
      # Manual smoke check: build a manager from the configured settings, fetch
      # at most 10 rows whose parent_id equals a fixed id from the
      # "rag_parent_hybrid" collection, and print the raw result.
      manager = MilvusManager(MilvusConfig())
      parent_filter = "parent_id == '02267e1d-11d7-4a3d-b53f-e205edd6758f'"
      rows = manager.condition_query(
          collection_name="rag_parent_hybrid",
          filter=parent_filter,
          limit=10,
      )
      print(rows)