config.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # time: 2023/3/2 10:28
  4. # file: config.py
  5. # author: David
  6. # company: shenyang JY
  7. """
  8. 模型调参及系统功能配置
  9. """
  10. import os, threading
  11. import argparse, yaml
  12. from app.common.logs import Log
  13. class myargparse(argparse.ArgumentParser):
  14. _save_lock = threading.Lock()
  15. def __init__(self, description, add_help):
  16. super(myargparse, self).__init__(description=description, add_help=add_help)
  17. self.add_argument(
  18. 'input_file',
  19. type=str,
  20. metavar='FILE',
  21. help='训练预测数据路径')
  22. self.add_argument(
  23. 'moment',
  24. type=str,
  25. metavar='FILE',
  26. help='时刻')
  27. self.add_argument(
  28. '-c',
  29. '--config_yaml',
  30. default=
  31. 'config.yml',
  32. type=str,
  33. metavar='FILE',
  34. help='YAML config file specifying default arguments')
  35. self.add_argument(
  36. '-n',
  37. '--neu_yaml',
  38. default=
  39. 'neu.yml',
  40. type=str,
  41. metavar='FILE',
  42. help='YAML config file specifying default arguments')
  43. self.add_argument(
  44. '-m',
  45. '--model_name',
  46. type=str,
  47. metavar='FILE',
  48. help='模型选择')
  49. self.add_argument(
  50. '-o',
  51. '--train_mode',
  52. default=False,
  53. type=bool,
  54. metavar='train mode',
  55. help='训练')
  56. def _parse_args_and_yaml(self):
  57. base_parser = argparse.ArgumentParser(add_help=False)
  58. base_parser.add_argument('-c', '--config_yaml', default='config.yml', type=str)
  59. base_parser.add_argument('-n', '--neu_yaml', default='neu.yml', type=str)
  60. given_configs, remaining = base_parser.parse_known_args()
  61. current_path = os.path.dirname(os.path.dirname(__file__))
  62. if given_configs.config_yaml:
  63. config_path = os.path.join(current_path, 'common', given_configs.config_yaml)
  64. with open(config_path, 'r', encoding='utf-8') as f:
  65. cfg = yaml.safe_load(f)
  66. self.set_defaults(**cfg)
  67. if given_configs.neu_yaml:
  68. config_path = os.path.join(current_path, 'common', given_configs.neu_yaml)
  69. with open(config_path, 'r', encoding='utf-8') as f:
  70. model_cfg = yaml.safe_load(f)
  71. self.set_defaults(**model_cfg)
  72. # defaults will have been overridden if config file specified.
  73. # opt = self.parse_args(remaining)
  74. opt = self.parse_args(remaining)
  75. # Cache the args as a text string to save them in the output dir later
  76. opt_text = yaml.safe_dump(opt.__dict__, default_flow_style=False)
  77. return opt, opt_text
  78. def parse_args_and_yaml(self):
  79. return self._parse_args_and_yaml()[0]
  80. def save_args_yml(self, opt):
  81. current_path = os.path.dirname(__file__)
  82. with myargparse._save_lock:
  83. file_path = os.path.join(current_path, 'config.yml')
  84. with open(file_path, mode='w', encoding='utf-8') as f:
  85. yaml.safe_dump(vars(opt), f)
  86. # -------------- 设置全局对象 --------------
  87. # 创建日志对象
  88. logger = Log().logger
  89. # 创建解析器对象
  90. parser = myargparse(description='南网竞赛算法', add_help=False)
  91. if __name__ == '__main__':
  92. pass