# graph_models.py
"""
Graph data model definitions.

Provides general-purpose data structure definitions for knowledge graphs.
"""
  5. from typing import Optional, Dict, Any, List, Union
  6. from dataclasses import dataclass
  7. from datetime import datetime
  8. from enum import Enum
  9. class NodeType(Enum):
  10. """节点类型枚举"""
  11. PERSON = "person"
  12. ORGANIZATION = "organization"
  13. LOCATION = "location"
  14. CONCEPT = "concept"
  15. EVENT = "event"
  16. DOCUMENT = "document"
  17. UNKNOWN = "unknown"
  18. class RelationType(Enum):
  19. """关系类型枚举"""
  20. BELONGS_TO = "belongs_to"
  21. LOCATED_IN = "located_in"
  22. RELATED_TO = "related_to"
  23. PART_OF = "part_of"
  24. INSTANCE_OF = "instance_of"
  25. KNOWS = "knows"
  26. WORKS_FOR = "works_for"
  27. UNKNOWN = "unknown"
  28. @dataclass
  29. class GraphNode:
  30. """图节点数据模型"""
  31. id: Optional[str] = None
  32. label: str = ""
  33. node_type: NodeType = NodeType.UNKNOWN
  34. properties: Optional[Dict[str, Any]] = None
  35. embeddings: Optional[List[float]] = None
  36. created_at: Optional[datetime] = None
  37. updated_at: Optional[datetime] = None
  38. def __post_init__(self):
  39. if self.properties is None:
  40. self.properties = {}
  41. if isinstance(self.node_type, str):
  42. self.node_type = NodeType(self.node_type)
  43. def to_dict(self) -> Dict[str, Any]:
  44. """转换为字典"""
  45. return {
  46. 'id': self.id,
  47. 'label': self.label,
  48. 'node_type': self.node_type.value if self.node_type else None,
  49. 'properties': self.properties,
  50. 'embeddings': self.embeddings,
  51. 'created_at': self.created_at.isoformat() if self.created_at else None,
  52. 'updated_at': self.updated_at.isoformat() if self.updated_at else None
  53. }
  54. @classmethod
  55. def from_dict(cls, data: Dict[str, Any]) -> 'GraphNode':
  56. """从字典创建实例"""
  57. node_type = data.get('node_type')
  58. if isinstance(node_type, str):
  59. node_type = NodeType(node_type)
  60. return cls(
  61. id=data.get('id'),
  62. label=data.get('label', ''),
  63. node_type=node_type,
  64. properties=data.get('properties', {}),
  65. embeddings=data.get('embeddings', []),
  66. created_at=datetime.fromisoformat(data['created_at']) if data.get('created_at') else None,
  67. updated_at=datetime.fromisoformat(data['updated_at']) if data.get('updated_at') else None
  68. )
  69. @dataclass
  70. class GraphEdge:
  71. """图边数据模型"""
  72. id: Optional[str] = None
  73. source_id: str = ""
  74. target_id: str = ""
  75. relation_type: RelationType = RelationType.UNKNOWN
  76. weight: float = 1.0
  77. properties: Optional[Dict[str, Any]] = None
  78. created_at: Optional[datetime] = None
  79. def __post_init__(self):
  80. if self.properties is None:
  81. self.properties = {}
  82. if isinstance(self.relation_type, str):
  83. self.relation_type = RelationType(self.relation_type)
  84. def to_dict(self) -> Dict[str, Any]:
  85. """转换为字典"""
  86. return {
  87. 'id': self.id,
  88. 'source_id': self.source_id,
  89. 'target_id': self.target_id,
  90. 'relation_type': self.relation_type.value if self.relation_type else None,
  91. 'weight': self.weight,
  92. 'properties': self.properties,
  93. 'created_at': self.created_at.isoformat() if self.created_at else None
  94. }
  95. @classmethod
  96. def from_dict(cls, data: Dict[str, Any]) -> 'GraphEdge':
  97. """从字典创建实例"""
  98. relation_type = data.get('relation_type')
  99. if isinstance(relation_type, str):
  100. relation_type = RelationType(relation_type)
  101. return cls(
  102. id=data.get('id'),
  103. source_id=data.get('source_id', ''),
  104. target_id=data.get('target_id', ''),
  105. relation_type=relation_type,
  106. weight=data.get('weight', 1.0),
  107. properties=data.get('properties', {}),
  108. created_at=datetime.fromisoformat(data['created_at']) if data.get('created_at') else None
  109. )
  110. @dataclass
  111. class GraphEntity:
  112. """图实体数据模型(扩展的节点模型)"""
  113. node: GraphNode
  114. entity_type: str = ""
  115. confidence: float = 1.0
  116. source_document: Optional[str] = None
  117. extraction_method: Optional[str] = None
  118. def to_dict(self) -> Dict[str, Any]:
  119. """转换为字典"""
  120. return {
  121. 'node': self.node.to_dict(),
  122. 'entity_type': self.entity_type,
  123. 'confidence': self.confidence,
  124. 'source_document': self.source_document,
  125. 'extraction_method': self.extraction_method
  126. }
  127. @classmethod
  128. def from_dict(cls, data: Dict[str, Any]) -> 'GraphEntity':
  129. """从字典创建实例"""
  130. node_data = data.get('node', {})
  131. node = GraphNode.from_dict(node_data)
  132. return cls(
  133. node=node,
  134. entity_type=data.get('entity_type', ''),
  135. confidence=data.get('confidence', 1.0),
  136. source_document=data.get('source_document'),
  137. extraction_method=data.get('extraction_method')
  138. )
  139. @dataclass
  140. class GraphRelation:
  141. """图关系数据模型(扩展的边模型)"""
  142. edge: GraphEdge
  143. relation_subtype: Optional[str] = None
  144. confidence: float = 1.0
  145. source_sentence: Optional[str] = None
  146. extraction_method: Optional[str] = None
  147. def to_dict(self) -> Dict[str, Any]:
  148. """转换为字典"""
  149. return {
  150. 'edge': self.edge.to_dict(),
  151. 'relation_subtype': self.relation_subtype,
  152. 'confidence': self.confidence,
  153. 'source_sentence': self.source_sentence,
  154. 'extraction_method': self.extraction_method
  155. }
  156. @classmethod
  157. def from_dict(cls, data: Dict[str, Any]) -> 'GraphRelation':
  158. """从字典创建实例"""
  159. edge_data = data.get('edge', {})
  160. edge = GraphEdge.from_dict(edge_data)
  161. return cls(
  162. edge=edge,
  163. relation_subtype=data.get('relation_subtype'),
  164. confidence=data.get('confidence', 1.0),
  165. source_sentence=data.get('source_sentence'),
  166. extraction_method=data.get('extraction_method')
  167. )
  168. @dataclass
  169. class KnowledgeGraph:
  170. """知识图谱数据模型"""
  171. id: Optional[str] = None
  172. name: str = ""
  173. description: Optional[str] = None
  174. nodes: List[GraphEntity] = None
  175. relations: List[GraphRelation] = None
  176. metadata: Optional[Dict[str, Any]] = None
  177. created_at: Optional[datetime] = None
  178. updated_at: Optional[datetime] = None
  179. def __post_init__(self):
  180. if self.nodes is None:
  181. self.nodes = []
  182. if self.relations is None:
  183. self.relations = []
  184. if self.metadata is None:
  185. self.metadata = {}
  186. def to_dict(self) -> Dict[str, Any]:
  187. """转换为字典"""
  188. return {
  189. 'id': self.id,
  190. 'name': self.name,
  191. 'description': self.description,
  192. 'nodes': [node.to_dict() for node in self.nodes],
  193. 'relations': [relation.to_dict() for relation in self.relations],
  194. 'metadata': self.metadata,
  195. 'created_at': self.created_at.isoformat() if self.created_at else None,
  196. 'updated_at': self.updated_at.isoformat() if self.updated_at else None
  197. }
  198. @classmethod
  199. def from_dict(cls, data: Dict[str, Any]) -> 'KnowledgeGraph':
  200. """从字典创建实例"""
  201. nodes_data = data.get('nodes', [])
  202. relations_data = data.get('relations', [])
  203. nodes = [GraphEntity.from_dict(node_data) for node_data in nodes_data]
  204. relations = [GraphRelation.from_dict(relation_data) for relation_data in relations_data]
  205. return cls(
  206. id=data.get('id'),
  207. name=data.get('name', ''),
  208. description=data.get('description'),
  209. nodes=nodes,
  210. relations=relations,
  211. metadata=data.get('metadata', {}),
  212. created_at=datetime.fromisoformat(data['created_at']) if data.get('created_at') else None,
  213. updated_at=datetime.fromisoformat(data['updated_at']) if data.get('updated_at') else None
  214. )
  215. __all__ = [
  216. "NodeType",
  217. "RelationType",
  218. "GraphNode",
  219. "GraphEdge",
  220. "GraphEntity",
  221. "GraphRelation",
  222. "KnowledgeGraph"
  223. ]