embedding.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎
  5. @file: embedding.py
  6. @date:2024/10/17 15:29
  7. @desc:
  8. """
  9. import base64
  10. import json
  11. from typing import Dict, Optional
  12. from langchain_community.embeddings import SparkLLMTextEmbeddings
  13. from numpy import ndarray
  14. from models_provider.base_model_provider import MaxKBBaseModel
  15. import time
  16. import json
  17. import base64
  18. import numpy as np
  19. import threading
  20. import queue
  21. _task_queue = queue.Queue()
  22. def _worker():
  23. while True:
  24. message, future = _task_queue.get()
  25. for i in range(3):
  26. try:
  27. data = json.loads(message)
  28. code = data["header"]["code"]
  29. if code != 0:
  30. raise Exception(f"Request error: {code}, {data}")
  31. text_base = data["payload"]["feature"]["text"]
  32. text_data = base64.b64decode(text_base)
  33. dt = np.dtype(np.float32)
  34. dt = dt.newbyteorder("<")
  35. text = np.frombuffer(text_data, dtype=dt)
  36. if len(text) > 2560:
  37. array = text[:2560]
  38. else:
  39. array = text
  40. future["result"] = array
  41. future["event"].set()
  42. break
  43. except Exception as e:
  44. if i == 2:
  45. future["error"] = e
  46. future["event"].set()
  47. else:
  48. time.sleep(0.5)
  49. time.sleep(0.5) # QPS=2
  50. threading.Thread(target=_worker, daemon=True).start()
  51. class XFEmbedding(MaxKBBaseModel, SparkLLMTextEmbeddings):
  52. @staticmethod
  53. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  54. return XFEmbedding(
  55. base_url=model_credential.get('base_url'),
  56. spark_app_id=model_credential.get('spark_app_id'),
  57. spark_api_key=model_credential.get('spark_api_key'),
  58. spark_api_secret=model_credential.get('spark_api_secret')
  59. )
  60. @staticmethod
  61. def _parser_message(
  62. message: str,
  63. ) -> Optional[ndarray]:
  64. future = {
  65. "event": threading.Event(),
  66. "result": None,
  67. "error": None
  68. }
  69. _task_queue.put((message, future))
  70. future["event"].wait()
  71. if future["error"]:
  72. raise future["error"]
  73. return future["result"]