12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- # time: 2023/3/2 10:28
- # file: config.py
- # author: David
- # company: shenyang JY
- """
- 模型调参及系统功能配置
- """
- import os, threading
- import argparse, yaml
- from app.common.logs import Log
class myargparse(argparse.ArgumentParser):
    """Argument parser whose defaults are layered from YAML config files.

    Defaults are loaded from two YAML files named by ``-c/--config_yaml`` and
    ``-n/--neu_yaml``. The files are looked up in ``<parent of this module's
    directory>/common/``. Explicit command-line flags override those defaults.
    """

    # Class-wide lock guarding the YAML files: held while ``save_args_yml``
    # rewrites config.yml and while ``_load_yaml`` reads a config file. This
    # stops a reader in this process from seeing a truncated file.
    _save_lock = threading.Lock()

    def __init__(self, description, add_help):
        """Create the parser and register the supported options.

        Args:
            description: text shown in the help/usage output.
            add_help: whether argparse should register ``-h/--help``.
        """
        super().__init__(description=description, add_help=add_help)
        self.add_argument(
            '-c',
            '--config_yaml',
            default='config.yml',
            type=str,
            metavar='FILE',
            help='YAML config file specifying default arguments')
        self.add_argument(
            '-n',
            '--neu_yaml',
            default='neu.yml',
            type=str,
            metavar='FILE',
            help='YAML config file specifying default arguments')
        # Path to the training / prediction data.
        self.add_argument(
            '-i',
            '--input_file',
            type=str,
            metavar='FILE',
            help='训练预测数据路径')
        # Model selection. The value is a model name, not a file path.
        self.add_argument(
            '-m',
            '--model_name',
            type=str,
            metavar='NAME',
            help='模型选择')

    def _load_yaml(self, config_dir, file_name):
        """Read ``config_dir/file_name`` and return its top-level mapping.

        Args:
            config_dir: directory containing the YAML file.
            file_name: name of the YAML file.

        Returns:
            The parsed mapping, or ``{}`` when the file is empty.

        Raises:
            OSError: if the file cannot be opened.
            TypeError: if the YAML document is not a mapping, so it cannot be
                used as keyword defaults.
        """
        config_path = os.path.join(config_dir, file_name)
        # Same lock as save_args_yml, so a concurrent rewrite is never
        # observed half-written.
        with myargparse._save_lock:
            with open(config_path, 'r', encoding='utf-8') as f:
                loaded = yaml.safe_load(f)
        if loaded is None:
            # An empty YAML file parses to None; treat it as "no defaults".
            return {}
        if not isinstance(loaded, dict):
            raise TypeError('{} must contain a YAML mapping, got {}'.format(
                config_path, type(loaded).__name__))
        return loaded

    def _parse_args_and_yaml(self, args=None):
        """Parse command-line arguments on top of YAML-provided defaults.

        Args:
            args: argument list to parse. ``None`` means ``sys.argv[1:]``,
                as with ``argparse``.

        Returns:
            ``(opt, opt_text)``. ``opt`` is the parsed namespace; ``opt_text``
            is that namespace serialized as a YAML string.
        """
        # Pre-parse only the config-file options so the YAML files can be
        # loaded before the full parse; all other arguments are kept for later.
        base_parser = argparse.ArgumentParser(add_help=False)
        base_parser.add_argument('-c', '--config_yaml', default='config.yml', type=str)
        base_parser.add_argument('-n', '--neu_yaml', default='neu.yml', type=str)
        given_configs, remaining = base_parser.parse_known_args(args)
        package_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        config_dir = os.path.join(package_dir, 'common')
        # Files are applied in order, so a key in the neu file overrides the
        # same key from the config file. An empty name skips that file.
        for file_name in (given_configs.config_yaml, given_configs.neu_yaml):
            if file_name:
                self.set_defaults(**self._load_yaml(config_dir, file_name))
        # Defaults now come from the YAML files; explicit flags override them.
        opt = self.parse_args(remaining)
        # `remaining` no longer contains -c/-n, so record the files actually used.
        opt.config_yaml = given_configs.config_yaml
        opt.neu_yaml = given_configs.neu_yaml
        # Serialized copy of the options, e.g. for saving into an output dir.
        opt_text = yaml.safe_dump(opt.__dict__, default_flow_style=False)
        return opt, opt_text

    def parse_args_and_yaml(self, args=None):
        """Return only the parsed namespace from ``_parse_args_and_yaml``."""
        return self._parse_args_and_yaml(args)[0]

    def save_args_yml(self, opt):
        """Write the options in ``opt`` to ``config.yml`` in this module's directory.

        The write holds the class-wide lock that ``_load_yaml`` also takes.

        Args:
            opt: namespace whose ``vars()`` are dumped as YAML.
        """
        file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.yml')
        with myargparse._save_lock:
            with open(file_path, mode='w', encoding='utf-8') as f:
                yaml.safe_dump(vars(opt), f)
# -------------- Global objects created on import --------------
# Logger obtained from the project's Log wrapper.
logger = Log().logger
# Shared argument parser; the description reads "China Southern Grid
# competition algorithm". add_help=False, so no -h/--help option is registered.
parser = myargparse(add_help=False, description='南网竞赛算法')

if __name__ == '__main__':
    # Running this file directly does nothing; importing it is what creates
    # `logger` and `parser` above.
    pass
|