model_apply_serializers.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎
  5. @file: model_apply_serializers.py
  6. @date:2024/8/20 20:39
  7. @desc:
  8. """
  9. from django.db import connection
  10. from django.db.models import QuerySet
  11. from langchain_core.documents import Document
  12. from rest_framework import serializers
  13. from common.config.embedding_config import ModelManage
  14. from django.utils.translation import gettext_lazy as _
  15. from models_provider.models import Model
  16. from models_provider.tools import get_model
  17. def get_embedding_model(model_id):
  18. model = QuerySet(Model).filter(id=model_id).first()
  19. # 手动关闭数据库连接
  20. connection.close()
  21. embedding_model = ModelManage.get_model(model_id,
  22. lambda _id: get_model(model, use_local=True))
  23. return embedding_model
  24. class EmbedDocuments(serializers.Serializer):
  25. texts = serializers.ListField(required=True, child=serializers.CharField(required=True,
  26. label=_('vector text')),
  27. label=_('vector text list')),
  28. class EmbedQuery(serializers.Serializer):
  29. text = serializers.CharField(required=True, label=_('vector text'))
  30. class CompressDocument(serializers.Serializer):
  31. page_content = serializers.CharField(required=True, label=_('text'))
  32. metadata = serializers.DictField(required=False, label=_('metadata'))
  33. class CompressDocuments(serializers.Serializer):
  34. documents = CompressDocument(required=True, many=True)
  35. query = serializers.CharField(required=True, label=_('query'))
  36. class ModelApplySerializers(serializers.Serializer):
  37. model_id = serializers.UUIDField(required=True, label=_('model id'))
  38. def embed_documents(self, instance, with_valid=True):
  39. if with_valid:
  40. self.is_valid(raise_exception=True)
  41. EmbedDocuments(data=instance).is_valid(raise_exception=True)
  42. model = get_embedding_model(self.data.get('model_id'))
  43. return model.embed_documents(instance.getlist('texts'))
  44. def embed_query(self, instance, with_valid=True):
  45. if with_valid:
  46. self.is_valid(raise_exception=True)
  47. EmbedQuery(data=instance).is_valid(raise_exception=True)
  48. model = get_embedding_model(self.data.get('model_id'))
  49. return model.embed_query(instance.get('text'))
  50. def compress_documents(self, instance, with_valid=True):
  51. if with_valid:
  52. self.is_valid(raise_exception=True)
  53. CompressDocuments(data=instance).is_valid(raise_exception=True)
  54. model = get_embedding_model(self.data.get('model_id'))
  55. return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents(
  56. [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in
  57. instance.get('documents')], instance.get('query'))]