123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- # time: 2023/3/2 10:28
- # file: config.py
- # author: David
- # company: shenyang JY
- """
- 模型调参及系统功能配置
- """
- import os
- import argparse, yaml
class myargparse(argparse.ArgumentParser):
    """Argument parser whose defaults can be pre-filled from a YAML file.

    The YAML file is chosen with ``-c/--config_yaml`` (default ``config.yml``)
    and is resolved relative to the directory containing this module unless
    an absolute path is given. Its top-level keys are installed as parser
    defaults, so options passed explicitly on the command line still win.
    """

    def __init__(self, discription, add_help):
        """Create the parser and register the ``-c/--config_yaml`` option.

        Args:
            discription: Parser description shown in help output. (The
                misspelling is part of the public keyword interface and is
                kept so existing callers keep working.)
            add_help: Forwarded to ``argparse.ArgumentParser``; whether to
                add the ``-h/--help`` option.
        """
        super().__init__(description=discription, add_help=add_help)
        self.add_argument(
            '-c',
            '--config_yaml',
            default='config.yml',
            type=str,
            metavar='FILE',
            help='YAML config file specifying default arguments')

    @staticmethod
    def _resolve_path(filename):
        """Return *filename* resolved against this module's directory.

        ``os.path.join`` leaves absolute paths untouched. Unlike concatenating
        with ``'/'``, it does not turn a relative name into a root-level path
        when ``os.path.dirname(__file__)`` is an empty string.
        """
        return os.path.join(os.path.dirname(__file__), filename)

    def _parse_args_and_yaml(self):
        """Parse the command line, using the YAML config file as defaults.

        Returns:
            tuple: ``(opt, opt_text)`` where ``opt`` is the parsed
            ``argparse.Namespace`` and ``opt_text`` is a YAML dump of it.

        Raises:
            TypeError: if the YAML file's top level is not a mapping.
        """
        # First pass: pick out -c/--config_yaml; unknown args are kept for later.
        given_configs, remaining = self.parse_known_args()
        if given_configs.config_yaml:
            config_path = self._resolve_path(given_configs.config_yaml)
            with open(config_path, 'r', encoding='utf-8') as f:
                cfg = yaml.safe_load(f)
            # An empty YAML document loads as None: treat it as "no overrides".
            if cfg is None:
                cfg = {}
            if not isinstance(cfg, dict):
                raise TypeError(
                    'config file %r must contain a YAML mapping, got %s'
                    % (config_path, type(cfg).__name__))
            self.set_defaults(**cfg)
        # Second pass: the remaining CLI args override the YAML-provided defaults.
        opt = self.parse_args(remaining)
        # -c/--config_yaml was consumed by the first pass, so without this the
        # namespace would report the default rather than the file actually used.
        opt.config_yaml = given_configs.config_yaml
        # Cache the args as YAML text so they can be saved with outputs later.
        opt_text = yaml.safe_dump(opt.__dict__, default_flow_style=False)
        return opt, opt_text

    def parse_args_and_yaml(self):
        """Parse CLI + YAML config and return only the resulting namespace."""
        return self._parse_args_and_yaml()[0]

    def save_args_yml(self, opt, isdict=False):
        """Write *opt* as YAML to ``config.yml`` next to this module.

        Args:
            opt: An ``argparse.Namespace`` (anything ``vars()`` accepts), or a
                plain dict when *isdict* is true.
            isdict: Set to true when *opt* is already a dict.
        """
        # Build the mapping before opening the file, so a bad *opt* cannot
        # leave an already-truncated config.yml behind.
        data = opt if isdict else vars(opt)
        with open(self._resolve_path('config.yml'), mode='w', encoding='utf-8') as f:
            yaml.safe_dump(data, f)
if __name__ == '__main__':
    # Script entry: build the station-side parser, then load config.yml
    # (or the file given with -c) merged with any command-line overrides.
    parser = myargparse(discription="场站端配置", add_help=False)
    opt = parser.parse_args_and_yaml()
|