config.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # time: 2023/3/2 10:28
  4. # file: config.py
  5. # author: David
  6. # company: shenyang JY
  7. """
  8. 模型调参及系统功能配置
  9. """
  10. import os, threading
  11. import argparse, yaml
  12. from app.common.logs import Log
  13. class myargparse(argparse.ArgumentParser):
  14. _save_lock = threading.Lock()
  15. def __init__(self, description, add_help):
  16. super(myargparse, self).__init__(description=description, add_help=add_help)
  17. self.add_argument(
  18. '-c',
  19. '--config_yaml',
  20. default=
  21. 'config.yml',
  22. type=str,
  23. metavar='FILE',
  24. help='YAML config file specifying default arguments')
  25. self.add_argument(
  26. '-n',
  27. '--neu_yaml',
  28. default=
  29. 'neu.yml',
  30. type=str,
  31. metavar='FILE',
  32. help='YAML config file specifying default arguments')
  33. self.add_argument(
  34. '-i',
  35. '--input_file',
  36. type=str,
  37. metavar='FILE',
  38. help='训练预测数据路径')
  39. self.add_argument(
  40. '-m',
  41. '--model_name',
  42. type=str,
  43. metavar='FILE',
  44. help='模型选择')
  45. def _parse_args_and_yaml(self):
  46. base_parser = argparse.ArgumentParser(add_help=False)
  47. base_parser.add_argument('-c', '--config_yaml', default='config.yml', type=str)
  48. base_parser.add_argument('-n', '--neu_yaml', default='neu.yml', type=str)
  49. given_configs, remaining = base_parser.parse_known_args()
  50. current_path = os.path.dirname(os.path.dirname(__file__))
  51. if given_configs.config_yaml:
  52. config_path = os.path.join(current_path, 'common', given_configs.config_yaml)
  53. with open(config_path, 'r', encoding='utf-8') as f:
  54. cfg = yaml.safe_load(f)
  55. self.set_defaults(**cfg)
  56. if given_configs.neu_yaml:
  57. config_path = os.path.join(current_path, 'common', given_configs.neu_yaml)
  58. with open(config_path, 'r', encoding='utf-8') as f:
  59. model_cfg = yaml.safe_load(f)
  60. self.set_defaults(**model_cfg)
  61. # defaults will have been overridden if config file specified.
  62. # opt = self.parse_args(remaining)
  63. opt = self.parse_args(remaining)
  64. # Cache the args as a text string to save them in the output dir later
  65. opt_text = yaml.safe_dump(opt.__dict__, default_flow_style=False)
  66. return opt, opt_text
  67. def parse_args_and_yaml(self):
  68. return self._parse_args_and_yaml()[0]
  69. def save_args_yml(self, opt):
  70. current_path = os.path.dirname(__file__)
  71. with myargparse._save_lock:
  72. file_path = os.path.join(current_path, 'config.yml')
  73. with open(file_path, mode='w', encoding='utf-8') as f:
  74. yaml.safe_dump(vars(opt), f)
  75. # -------------- 设置全局对象 --------------
  76. # 创建日志对象
  77. logger = Log().logger
  78. # 创建解析器对象
  79. parser = myargparse(description='南网竞赛算法', add_help=False)
  80. if __name__ == '__main__':
  81. pass