config.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # time: 2023/3/2 10:28
  4. # file: config.py
  5. # author: David
  6. # company: shenyang JY
  7. """
  8. 模型调参及系统功能配置
  9. """
  10. import os, threading
  11. import argparse, yaml
  12. class myargparse(argparse.ArgumentParser):
  13. _save_lock = threading.Lock()
  14. def __init__(self, description, add_help):
  15. super(myargparse, self).__init__(description=description, add_help=add_help)
  16. self.add_argument(
  17. '-c',
  18. '--config_yaml',
  19. default=
  20. 'config.yml',
  21. type=str,
  22. metavar='FILE',
  23. help='YAML config file specifying default arguments')
  24. self.add_argument(
  25. '-i',
  26. '--input_file',
  27. type=str,
  28. metavar='FILE',
  29. help='训练预测数据路径')
  30. self.add_argument(
  31. '-m',
  32. '--model_name',
  33. type=str,
  34. metavar='FILE',
  35. help='模型选择')
  36. def _parse_args_and_yaml(self):
  37. base_parser = argparse.ArgumentParser(add_help=False)
  38. base_parser.add_argument('-c', '--config_yaml', default='config.yml', type=str)
  39. given_configs, remaining = base_parser.parse_known_args()
  40. current_path = os.path.dirname(__file__)
  41. if given_configs.config_yaml:
  42. config_path = os.path.join(current_path, given_configs.config_yaml)
  43. with open(config_path, 'r', encoding='utf-8') as f:
  44. cfg = yaml.safe_load(f)
  45. self.set_defaults(**cfg)
  46. # defaults will have been overridden if config file specified.
  47. # opt = self.parse_args(remaining)
  48. opt = self.parse_args(remaining)
  49. # Cache the args as a text string to save them in the output dir later
  50. opt_text = yaml.safe_dump(opt.__dict__, default_flow_style=False)
  51. return opt, opt_text
  52. def parse_args_and_yaml(self):
  53. return self._parse_args_and_yaml()[0]
  54. def save_args_yml(self, opt):
  55. current_path = os.path.dirname(__file__)
  56. with myargparse._save_lock:
  57. file_path = os.path.join(current_path, 'config.yml')
  58. with open(file_path, mode='w', encoding='utf-8') as f:
  59. yaml.safe_dump(vars(opt), f)
  60. if __name__ == '__main__':
  61. args = myargparse(discription="场站端配置", add_help=False)
  62. opt = args.parse_args_and_yaml()