| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212 |
- # coding=utf-8
- """
- @project: maxkb
- @Author:虎
- @file: base_vector.py
- @date:2023/10/18 19:16
- @desc:
- """
- import re
- import threading
- from abc import ABC, abstractmethod
- from functools import reduce
- from typing import List, Dict
- from langchain_core.embeddings import Embeddings
- from common.chunk import text_to_chunk
- from common.utils.common import sub_array
- from knowledge.models import SourceType, SearchMode
# Module-level lock. No code visible in this module acquires it; presumably it
# is imported and used by concrete vector-store implementations - verify
# against callers before changing or removing it.
lock = threading.Lock()
def chunk_data(data: Dict):
    """
    Expand one record into the list of records that will be embedded.

    A record whose ``source_type`` equals ``SourceType.PARAGRAPH.value`` (both
    compared as strings) is split: one copy of the record is produced per chunk,
    with ``text`` replaced by that chunk. The chunks come from the record's own
    ``chunks`` entry when it is truthy, otherwise from ``text_to_chunk(text)``.
    Any other record is returned unchanged, wrapped in a one-element list.

    :param data: record dict; reads the ``source_type``, ``text`` and ``chunks`` keys
    :return: list of record dicts
    """
    is_paragraph = str(data.get('source_type')) == str(SourceType.PARAGRAPH.value)
    if not is_paragraph:
        return [data]
    pieces = data.get('chunks') or text_to_chunk(data.get('text'))
    return [{**data, 'text': piece} for piece in pieces]
def chunk_data_list(data_list: List[Dict]):
    """
    Expand every record of ``data_list`` with ``chunk_data`` and return all the
    resulting records as one flat list, in input order.

    :param data_list: records to expand
    :return: a new flat list of record dicts (empty when ``data_list`` is empty)
    """
    # Flatten in a single pass. Folding with list concatenation rebuilds the
    # accumulated list on every step, which costs quadratic time on large inputs.
    return [record for data in data_list for record in chunk_data(data)]
# Patterns are compiled once at import time so every call reuses them.
RE_EMOJI = re.compile(
    "|".join((
        r"[\U0001F300-\U0001FAFF]",  # emoji code points U+1F300..U+1FAFF
        r"[\u2600-\u27BF]",  # U+2600..U+27BF: symbols and dingbats (the anchor sign is in here)
        r"[\uFE0E\uFE0F]",  # variation selectors
    )),
    flags=re.UNICODE,
)
# One or more consecutive whitespace characters.
RE_WHITESPACE = re.compile(r"\s+")
def normalize_for_embedding(text: str) -> str:
    """
    Clean ``text`` before it is handed to an embedding model.

    Everything matched by ``RE_EMOJI`` is removed, each run matched by
    ``RE_WHITESPACE`` is replaced by a single space, and the result is stripped
    of leading and trailing whitespace.

    :param text: raw text; any falsy value (``None``, ``""``) yields ``""``
    :return: the cleaned text
    """
    if text:
        without_emoji = RE_EMOJI.sub("", text)
        single_spaced = RE_WHITESPACE.sub(" ", without_emoji)
        return single_spaced.strip()
    return ""
class BaseVectorStore(ABC):
    """
    Abstract base class for vector-store backends.

    Subclasses implement the storage-specific primitives declared as abstract
    methods below; this class supplies the shared ``save`` / ``batch_save`` /
    ``search`` flow built on top of them.
    """

    # Class-level flag shared by all instances and subclasses. save_pre_handler
    # sets it to True after it has called vector_create().
    vector_exists = False

    @abstractmethod
    def vector_is_create(self) -> bool:
        """
        Report whether the vector store has already been created.

        :return: True when the vector store exists
        """
        pass

    @abstractmethod
    def vector_create(self):
        """
        Create the vector store.

        :return:
        """
        pass

    def save_pre_handler(self):
        """
        Pre-insert hook: make sure the vector store exists before writing.

        While the class-level ``vector_exists`` flag is still False, the backend
        is asked via ``vector_is_create()``; if it reports that the store is
        missing, ``vector_create()`` is called and the flag is set.

        :return: True
        """
        if not BaseVectorStore.vector_exists:
            if not self.vector_is_create():
                self.vector_create()
                BaseVectorStore.vector_exists = True
        return True

    def save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str,
             source_id: str,
             is_active: bool,
             embedding: Embeddings):
        """
        Insert the vector data for a single piece of text.

        The arguments are packed into one record dict, expanded with
        ``chunk_data``, split into batches with ``sub_array`` and each batch is
        handed to ``_batch_save``.

        :param text: text to store
        :param source_type: source type of the text
        :param knowledge_id: knowledge base id
        :param document_id: document id
        :param paragraph_id: paragraph id
        :param source_id: source id
        :param is_active: active flag stored with the record
        :param embedding: embedding model used to vectorise the text
        :return: None
        """
        self.save_pre_handler()
        data = {'document_id': document_id, 'paragraph_id': paragraph_id, 'knowledge_id': knowledge_id,
                'is_active': is_active, 'source_id': source_id, 'source_type': source_type, 'text': text}
        chunk_list = chunk_data(data)
        result = sub_array(chunk_list)
        for child_array in result:
            # A single save cannot be interrupted, so the callback always answers False.
            self._batch_save(child_array, embedding, lambda: False)

    def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
        """
        Insert the vector data for many records.

        The records are expanded with ``chunk_data_list`` and split into batches
        with ``sub_array``. Before each batch ``is_the_task_interrupted()`` is
        consulted; as soon as it returns a truthy value the remaining batches
        are skipped.

        :param data_list: list of record dicts
        :param embedding: embedding model used to vectorise the text
        :param is_the_task_interrupted: zero-argument callable, truthy when the task must stop
        :return: None
        """
        self.save_pre_handler()
        chunk_list = chunk_data_list(data_list)
        result = sub_array(chunk_list)
        for child_array in result:
            if not is_the_task_interrupted():
                self._batch_save(child_array, embedding, is_the_task_interrupted)
            else:
                break

    @abstractmethod
    def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str,
              source_id: str,
              is_active: bool,
              embedding: Embeddings):
        """
        Backend hook for storing a single text. Takes the same parameters as
        ``save``. It is not called by the code in this class.
        """
        pass

    @abstractmethod
    def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
        """
        Backend hook that stores one batch of records.

        :param text_list: one batch of record dicts, as produced by ``save`` / ``batch_save``
        :param embedding: embedding model used to vectorise the text
        :param is_the_task_interrupted: zero-argument callable, truthy when the task must stop
        """
        pass

    def search(self, query_text, knowledge_id_list: list[str], exclude_document_id_list: list[str],
               exclude_paragraph_list: list[str],
               is_active: bool,
               embedding: Embeddings):
        """
        Return the best match for ``query_text`` within the given knowledge bases.

        The query text is cleaned with ``normalize_for_embedding``, embedded with
        ``embedding.embed_query`` and passed to ``query``.

        :param query_text: text to search for
        :param knowledge_id_list: knowledge base ids to search in
        :param exclude_document_id_list: document ids to leave out
        :param exclude_paragraph_list: paragraph ids to leave out
        :param is_active: active flag forwarded to ``query``
        :param embedding: embedding model used to vectorise the query
        :return: the first entry returned by ``query``; ``[]`` when no knowledge
                 base is given or when ``query`` returns nothing
        """
        if not knowledge_id_list:
            return []
        query_text = normalize_for_embedding(query_text)
        embedding_query = embedding.embed_query(query_text)
        # Arguments follow the parameter order declared by the abstract query().
        # The previous call passed only eight positional values (ending in
        # `1, 3, 0.65`), which does not fit that ten-parameter signature.
        # NOTE(review): top_n=1 / similarity=0.65 are kept from the old call; the
        # stray `3` had no matching parameter and was dropped. SearchMode.embedding
        # is assumed to be the plain vector-search member - confirm.
        result = self.query(
            query_text,
            embedding_query,
            knowledge_id_list,
            None,  # document_id_list: no restriction to specific documents
            exclude_document_id_list,
            exclude_paragraph_list,
            is_active,
            1,  # top_n
            0.65,  # similarity
            SearchMode.embedding,
        )
        # Mirror the early return above instead of raising IndexError on no hits.
        if not result:
            return []
        return result[0]

    @abstractmethod
    def query(self, query_text: str, query_embedding: List[float], knowledge_id_list: list[str],
              document_id_list: list[str] | None,
              exclude_document_id_list: list[str],
              exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
              search_mode: SearchMode):
        """
        Backend hook that runs a search.

        :param query_text: query text
        :param query_embedding: embedding vector of the query text
        :param knowledge_id_list: knowledge base ids to search in
        :param document_id_list: document ids, or None
        :param exclude_document_id_list: document ids to leave out
        :param exclude_paragraph_list: paragraph ids to leave out
        :param is_active: active flag
        :param top_n: number of results requested
        :param similarity: similarity threshold
        :param search_mode: search mode
        :return: search results; ``search`` indexes the result with ``[0]``
        """
        pass

    @abstractmethod
    def hit_test(self, query_text, knowledge_id: list[str], exclude_document_id_list: list[str], top_number: int,
                 similarity: float,
                 search_mode: SearchMode,
                 embedding: Embeddings):
        """
        Backend hook for a hit test of ``query_text``; behaviour and return
        value are defined by the concrete backend.
        """
        pass

    @abstractmethod
    def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
        """Backend hook: update the entries of one paragraph with the values in ``instance``."""
        pass

    @abstractmethod
    def update_by_paragraph_ids(self, paragraph_ids: List[str], instance: Dict):
        """Backend hook: update the entries of several paragraphs with the values in ``instance``."""
        pass

    @abstractmethod
    def update_by_source_id(self, source_id: str, instance: Dict):
        """Backend hook: update the entries of one source with the values in ``instance``."""
        pass

    @abstractmethod
    def update_by_source_ids(self, source_ids: List[str], instance: Dict):
        """Backend hook: update the entries of several sources with the values in ``instance``."""
        pass

    @abstractmethod
    def delete_by_knowledge_id(self, knowledge_id: str):
        """Backend hook: delete the entries of one knowledge base."""
        pass

    @abstractmethod
    def delete_by_document_id(self, document_id: str):
        """Backend hook: delete the entries of one document."""
        pass

    @abstractmethod
    def delete_by_document_id_list(self, document_id_list: List[str]):
        """Backend hook: delete the entries of several documents."""
        pass

    @abstractmethod
    def delete_by_knowledge_id_list(self, knowledge_id_list: List[str]):
        """Backend hook: delete the entries of several knowledge bases."""
        pass

    @abstractmethod
    def delete_by_source_id(self, source_id: str, source_type: str):
        """Backend hook: delete the entries of one source of the given source type."""
        pass

    @abstractmethod
    def delete_by_source_ids(self, source_ids: List[str], source_type: str):
        """Backend hook: delete the entries of several sources of the given source type."""
        pass

    @abstractmethod
    def delete_by_paragraph_id(self, paragraph_id: str):
        """Backend hook: delete the entries of one paragraph."""
        pass

    @abstractmethod
    def delete_by_paragraph_ids(self, paragraph_ids: List[str]):
        """Backend hook: delete the entries of several paragraphs."""
        pass
|