reranker.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import os
  2. import re
  3. from typing import Dict, List, Sequence, Optional, Any
  4. from botocore.config import Config
  5. from langchain_aws import BedrockRerank
  6. from langchain_core.callbacks import Callbacks
  7. from langchain_core.documents import BaseDocumentCompressor, Document
  8. from pydantic import ConfigDict
  9. from models_provider.base_model_provider import MaxKBBaseModel
  10. from models_provider.impl.aws_bedrock_model_provider.model.llm import _update_aws_credentials
  11. class BedrockRerankerModel(MaxKBBaseModel, BaseDocumentCompressor):
  12. model_config = ConfigDict(arbitrary_types_allowed=True)
  13. model_id: Optional[str] = None
  14. model_arn: Optional[str] = None
  15. region_name: Optional[str] = None
  16. credentials_profile_name: Optional[str] = None
  17. aws_access_key_id: Optional[str] = None
  18. aws_secret_access_key: Optional[str] = None
  19. config: Optional[Any] = None
  20. top_n: Optional[int] = 3
  21. @staticmethod
  22. def is_cache_model():
  23. return False
  24. @staticmethod
  25. def new_instance(model_type: str, model_name: str, model_credential: Dict[str, str],
  26. **model_kwargs) -> 'BedrockRerankerModel':
  27. top_n = model_kwargs.get('top_n', 3)
  28. region_name = model_credential['region_name']
  29. model_arn = f"arn:aws:bedrock:{region_name}::foundation-model/{model_name}"
  30. config = None
  31. if 'base_url' in model_credential and model_credential['base_url']:
  32. proxy_url = model_credential['base_url']
  33. config = Config(
  34. proxies={
  35. 'http': proxy_url,
  36. 'https': proxy_url
  37. },
  38. connect_timeout=60,
  39. read_timeout=60
  40. )
  41. _update_aws_credentials(model_credential['access_key_id'], model_credential['access_key_id'],
  42. model_credential['secret_access_key'])
  43. return BedrockRerankerModel(
  44. model_id=model_name,
  45. model_arn=model_arn,
  46. region_name=region_name,
  47. credentials_profile_name=model_credential['access_key_id'],
  48. aws_access_key_id=model_credential['access_key_id'],
  49. aws_secret_access_key=model_credential['secret_access_key'],
  50. config=config,
  51. top_n=top_n
  52. )
  53. def compress_documents(self, documents: Sequence[Document], query: str,
  54. callbacks: Optional[Callbacks] = None) -> Sequence[Document]:
  55. """Compress documents using Bedrock reranking."""
  56. if not documents:
  57. return []
  58. reranker = BedrockRerank(
  59. model_arn=self.model_arn,
  60. region_name=self.region_name,
  61. credentials_profile_name=self.credentials_profile_name,
  62. aws_access_key_id=self.aws_access_key_id,
  63. aws_secret_access_key=self.aws_secret_access_key,
  64. config=self.config,
  65. top_n=self.top_n
  66. )
  67. return reranker.compress_documents(documents, query, callbacks)