#!/usr/bin/env python
# -*- coding: utf-8 -*-
# time: 2023/3/2 10:28
# file: config.py
# author: David
# company: shenyang JY
"""
Model hyper-parameter tuning and system feature configuration.

Defines ``myargparse``, an ``argparse.ArgumentParser`` whose defaults are
loaded from two YAML files (``config.yml`` and ``neu.yml`` by default) in the
``common`` directory, and which can write the resolved options back to a
``config.yml`` located next to this module.
"""
import argparse
import os
import threading

import yaml


class myargparse(argparse.ArgumentParser):
    """Argument parser whose defaults are read from YAML config files.

    Precedence, lowest to highest: defaults declared in ``__init__`` <
    keys from the ``--config_yaml`` file < keys from the ``--neu_yaml`` file <
    values given explicitly on the command line.
    """

    # Serializes writes to config.yml across concurrent save_args_yml calls.
    _save_lock = threading.Lock()

    def __init__(self, description, add_help):
        """Create the parser and declare the command-line options.

        Args:
            description: Text shown in the ``--help`` output.
            add_help: Whether argparse should add the ``-h/--help`` option.
        """
        super(myargparse, self).__init__(description=description, add_help=add_help)
        self.add_argument(
            '-c', '--config_yaml', default='config.yml', type=str, metavar='FILE',
            help='YAML config file specifying default arguments')
        self.add_argument(
            '-n', '--neu_yaml', default='neu.yml', type=str, metavar='FILE',
            help='YAML config file specifying default arguments')
        self.add_argument(
            '-i', '--input_file', type=str, metavar='FILE',
            help='训练预测数据路径')
        self.add_argument(
            '-m', '--model_name', type=str, metavar='FILE',
            help='模型选择')

    @staticmethod
    def _common_dir():
        """Return the ``common`` directory that holds the YAML config files."""
        # abspath guards against a relative __file__ (possible on older Pythons
        # when this module is run as a script): the two dirname calls would
        # otherwise collapse to '' and resolve against the current directory.
        root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        return os.path.join(root, 'common')

    def _load_yaml_defaults(self, file_name):
        """Load ``common/<file_name>`` and register its keys as parser defaults.

        Raises:
            OSError: If the file cannot be opened.
            TypeError: If the YAML document is not a mapping at the top level.
        """
        config_path = os.path.join(self._common_dir(), file_name)
        with open(config_path, 'r', encoding='utf-8') as f:
            cfg = yaml.safe_load(f)
        if cfg is None:
            # An empty YAML file loads as None; treat it as "no defaults"
            # instead of crashing on set_defaults(**None).
            return
        if not isinstance(cfg, dict):
            raise TypeError(
                "YAML config {!r} must contain a mapping at the top level, got {}".format(
                    config_path, type(cfg).__name__))
        self.set_defaults(**cfg)

    def _parse_args_and_yaml(self):
        """Parse ``sys.argv`` with the YAML-provided defaults layered underneath.

        Returns:
            tuple: ``(opt, opt_text)``. ``opt`` is the parsed
            ``argparse.Namespace``. ``opt_text`` is its YAML dump, cached so
            it can be saved alongside outputs later.
        """
        # First pass: only determine which YAML files to load; every other
        # argument is left in ``remaining`` for the full parser.
        base_parser = argparse.ArgumentParser(add_help=False)
        base_parser.add_argument('-c', '--config_yaml', default='config.yml', type=str)
        base_parser.add_argument('-n', '--neu_yaml', default='neu.yml', type=str)
        given_configs, remaining = base_parser.parse_known_args()

        # neu_yaml is applied second, so its keys override config_yaml's.
        if given_configs.config_yaml:
            self._load_yaml_defaults(given_configs.config_yaml)
        if given_configs.neu_yaml:
            self._load_yaml_defaults(given_configs.neu_yaml)

        # Explicit command-line values override the YAML-provided defaults.
        opt = self.parse_args(remaining)
        # -c/-n were consumed by the first pass, so the full parser only saw
        # their defaults; record the files that were actually loaded.
        opt.config_yaml = given_configs.config_yaml
        opt.neu_yaml = given_configs.neu_yaml

        opt_text = yaml.safe_dump(opt.__dict__, default_flow_style=False)
        return opt, opt_text

    def parse_args_and_yaml(self):
        """Return the parsed ``argparse.Namespace`` (see ``_parse_args_and_yaml``)."""
        return self._parse_args_and_yaml()[0]

    def save_args_yml(self, opt):
        """Overwrite ``config.yml`` next to this module with the fields of ``opt``.

        Args:
            opt: An ``argparse.Namespace``; every attribute value must be
                serializable by ``yaml.safe_dump``.
        """
        current_path = os.path.dirname(os.path.abspath(__file__))
        file_path = os.path.join(current_path, 'config.yml')
        with myargparse._save_lock:
            with open(file_path, mode='w', encoding='utf-8') as f:
                yaml.safe_dump(vars(opt), f)


if __name__ == '__main__':
    args = myargparse(description="场站端配置", add_help=False)
    opt = args.parse_args_and_yaml()