process.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import psutil
  2. import asyncio
  3. import logging
  4. import signal
  5. import os
  6. import threading
  7. from gpustack.utils import platform
  8. logger = logging.getLogger(__name__)
  9. threading_stop_event = threading.Event()
  10. termination_signal_handled = False
  11. def add_signal_handlers():
  12. signal.signal(signal.SIGTERM, handle_termination_signal)
  13. def add_signal_handlers_in_loop():
  14. if platform.system() == "windows":
  15. # Windows does not support asyncio signal handlers.
  16. add_signal_handlers()
  17. return
  18. loop = asyncio.get_event_loop()
  19. for sig in (signal.SIGINT, signal.SIGTERM):
  20. logger.debug(f"Adding signal handler for {sig}")
  21. loop.add_signal_handler(
  22. sig, lambda: asyncio.create_task(shutdown_event_loop(sig, loop))
  23. )
  24. async def shutdown_event_loop(sig=None, loop=None):
  25. logger.debug(f"Received signal: {sig}. Shutting down gracefully...")
  26. threading_stop_event.set()
  27. try:
  28. tasks = [t for t in asyncio.all_tasks(loop) if t is not asyncio.current_task()]
  29. for task in tasks:
  30. task.cancel()
  31. # Wait for all tasks to complete
  32. await asyncio.gather(*tasks, return_exceptions=True)
  33. except asyncio.CancelledError:
  34. pass
  35. handle_termination_signal(sig=sig)
  36. def handle_termination_signal(sig=None, frame=None):
  37. """
  38. Terminate the current process and all its children.
  39. """
  40. global termination_signal_handled
  41. if termination_signal_handled:
  42. return
  43. termination_signal_handled = True
  44. threading_stop_event.set()
  45. pid = os.getpid()
  46. terminate_process_tree(pid)
  47. def terminate_process_tree(pid: int):
  48. try:
  49. process = psutil.Process(pid)
  50. children = process.children(recursive=True)
  51. # Terminate all child processes
  52. terminate_processes(children)
  53. # Terminate the parent process
  54. terminate_process(process)
  55. except psutil.NoSuchProcess:
  56. pass
  57. except Exception as e:
  58. logger.error(f"Error while terminating process tree: {e}")
  59. def terminate_processes(processes):
  60. """
  61. Terminates a list of processes, attempting graceful termination first,
  62. then forcibly killing remaining ones if necessary.
  63. """
  64. for process in processes:
  65. try:
  66. process.terminate()
  67. except psutil.NoSuchProcess:
  68. continue
  69. # Wait for processes to terminate and kill if still alive
  70. _, alive_processes = psutil.wait_procs(processes, timeout=3)
  71. while alive_processes:
  72. for process in alive_processes:
  73. try:
  74. process.kill()
  75. except psutil.NoSuchProcess:
  76. continue
  77. _, alive_processes = psutil.wait_procs(alive_processes, timeout=1)
  78. def terminate_process(process):
  79. """
  80. Terminates a single process, attempting graceful termination first,
  81. then forcibly killing it if necessary.
  82. """
  83. if process.is_running():
  84. try:
  85. process.terminate()
  86. process.wait(timeout=3)
  87. except psutil.NoSuchProcess:
  88. pass
  89. except psutil.TimeoutExpired:
  90. try:
  91. process.kill()
  92. process.wait(timeout=1)
  93. except (psutil.NoSuchProcess, psutil.TimeoutExpired):
  94. pass