authenticator.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. """
  2. Authenticator for server-to-server federation.
  3. """
  4. import hmac
  5. import hashlib
  6. from abc import ABC, abstractmethod
  7. from typing import Optional, Dict
  8. from fastapi import WebSocket
  9. class Authenticator(ABC):
  10. """Abstract base class for authenticator implementations."""
  11. @abstractmethod
  12. def inject_headers(
  13. self,
  14. headers: Dict[str, str],
  15. ) -> None:
  16. """Inject auth headers into the headers dict for an outgoing connection."""
  17. @abstractmethod
  18. async def authenticate(self, websocket: WebSocket) -> bool:
  19. """Verify the signature from an incoming connection."""
  20. class HMACAuthenticator(Authenticator):
  21. """HMAC-SHA256 based authenticator."""
  22. def __init__(
  23. self,
  24. key: str,
  25. header_key: str = 'x-server-id',
  26. signature_header: str = 'x-auth-signature',
  27. ) -> None:
  28. self.key = key
  29. self.header_key = header_key
  30. self.signature_header = signature_header
  31. def inject_headers(
  32. self,
  33. headers: Dict[str, str],
  34. ) -> None:
  35. """Inject HMAC auth signature into headers."""
  36. server_id_str = headers.get(self.header_key, '')
  37. if server_id_str == '':
  38. raise ValueError("Missing server ID in headers for HMAC authentication")
  39. signature = hmac.new(
  40. self.key.encode(), server_id_str.encode(), hashlib.sha256
  41. ).hexdigest()
  42. headers[self.signature_header] = signature
  43. async def authenticate(self, websocket: WebSocket) -> bool:
  44. """Verify the signature from an incoming connection."""
  45. headers = websocket.headers
  46. provided = headers.get(self.signature_header, '')
  47. if not provided:
  48. return False
  49. server_id_str = headers.get(self.header_key, '')
  50. if not server_id_str:
  51. return False
  52. expected = hmac.new(
  53. self.key.encode(), server_id_str.encode(), hashlib.sha256
  54. ).hexdigest()
  55. return hmac.compare_digest(provided, expected)
  56. class NoOpAuthenticator(Authenticator):
  57. """Authenticator that accepts all connections (no auth)."""
  58. def inject_headers(
  59. self,
  60. _headers: Dict[str, str],
  61. ) -> None:
  62. """No-op: does not inject any headers."""
  63. pass
  64. async def authenticate(self, _websocket: WebSocket) -> bool:
  65. return True
  66. def create_authenticator(key: Optional[str]) -> Authenticator:
  67. """Factory to create an authenticator based on whether a key is provided."""
  68. if key:
  69. return HMACAuthenticator(key)
  70. return NoOpAuthenticator()