generate.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
import os
import shutil
from dataclasses import dataclass, field
from typing import List

from jinja2 import Environment, FileSystemLoader

from .filters import to_dash_plural, to_snake_case, to_plural, to_underscore_plural
  7. def main():
  8. cfg = Config(
  9. class_names=[
  10. "Worker",
  11. "Model",
  12. "ModelInstance",
  13. "ModelFile",
  14. "User",
  15. "InferenceBackend",
  16. "Benchmark",
  17. "ModelRouteTarget",
  18. ]
  19. )
  20. env = Environment(loader=FileSystemLoader(cfg.template_dir), auto_reload=True)
  21. env.filters["to_snake_case"] = to_snake_case
  22. env.filters["to_plural"] = to_plural
  23. env.filters["to_underscore_plural"] = to_underscore_plural
  24. env.filters["to_dash_plural"] = to_dash_plural
  25. reset(cfg)
  26. gen_http_clients(env, cfg)
  27. gen_clients(env, cfg)
  28. gen_clientset(env, cfg)
  29. write_init(cfg)
  30. print("Code gen succeeded!")
  31. @dataclass
  32. class Config:
  33. template_dir: str = os.path.join(os.path.dirname(__file__), "templates")
  34. output_dir: str = "gpustack/client"
  35. class_names: List[str] = None
  36. def gen_clients(env: Environment, cfg: Config):
  37. template = env.get_template("client.py.jinja")
  38. for class_name in cfg.class_names:
  39. data = {
  40. "class_name": class_name,
  41. }
  42. client_code = template.render(data)
  43. with open(
  44. f"{cfg.output_dir}/generated_{to_snake_case(class_name)}_client.py", "w"
  45. ) as f:
  46. f.write(client_code)
  47. def gen_clientset(env: Environment, cfg: Config):
  48. template = env.get_template("clientset.py.jinja")
  49. data = {
  50. "class_names": cfg.class_names,
  51. }
  52. client_code = template.render(data)
  53. with open(f"{cfg.output_dir}/generated_clientset.py", "w") as f:
  54. f.write(client_code)
  55. def gen_http_clients(env: Environment, cfg: Config):
  56. shutil.copyfile(
  57. f"{cfg.template_dir}/http_client.py.jinja",
  58. f"{cfg.output_dir}/generated_http_client.py",
  59. )
  60. def write_init(cfg: Config):
  61. with open(f"{cfg.output_dir}/__init__.py", "w") as f:
  62. f.write(
  63. """from .generated_clientset import ClientSet
  64. __all__ = ["ClientSet"]
  65. """
  66. )
  67. def reset(cfg: Config):
  68. output_dir = cfg.output_dir
  69. if os.path.exists(output_dir):
  70. for file in os.listdir(output_dir):
  71. if file == "__init__.py" or file.startswith("generated_"):
  72. os.remove(os.path.join(output_dir, file))
  73. if not os.path.exists(output_dir):
  74. os.makedirs(output_dir)
# Script entry point. Because of the relative ``from .filters import ...``
# import, this file has to be run as part of its package (for example
# ``python -m <package>.generate``), not as a standalone script.
if __name__ == "__main__":
    main()