#!/usr/bin/env python
# -*- coding: utf-8 -*-
# time: 2023/3/17 14:46
# file: config.py
# author: David
# company: shenyang JY
"""Command-line / YAML configuration parser.

``myargparse`` registers a few built-in options, then lets a YAML file
(``-c/--config_yaml``, default ``config.yml``) supply defaults before the
rest of the command line is parsed.  Precedence, lowest to highest:
defaults registered in ``__init__`` < YAML values < command-line options.
"""
import argparse
import os
import time

import yaml


class myargparse(argparse.ArgumentParser):
    """ArgumentParser whose defaults may be overridden from a YAML file."""

    def __init__(self, discription, add_help):
        """Create the parser and register the built-in options.

        Args:
            discription: Text shown by ``--help``; forwarded as
                ``description`` (the misspelled name is kept so existing
                keyword callers keep working).
            add_help: Forwarded unchanged to ``argparse.ArgumentParser``.
        """
        super().__init__(description=discription, add_help=add_help)
        self.add_argument(
            '-c', '--config_yaml', default='config.yml', type=str,
            metavar='FILE',
            help='YAML config file specifying default arguments')

        feature_columns = ['C_TIME', 'C_REAL_VALUE', 'C_FP_VALUE']
        label_columns = ['C_REAL_VALUE']
        # Position of each label within the feature list (the features do
        # not necessarily start at column 0 of the data).
        label_in_feature_index = [feature_columns.index(col)
                                  for col in label_columns]

        # Options that can be given on the command line but are not in the
        # YAML file.  nargs='+' replaces type=list: argparse hands each
        # command-line token to ``type`` as a single string, so type=list
        # split "--feature_columns C_TIME" into individual characters.
        # Defaults (and YAML-supplied lists) are passed through unchanged.
        self.add_argument('--feature_columns', type=str, nargs='+',
                          default=feature_columns, help='要作为特征的列')
        self.add_argument('--label_columns', type=str, nargs='+',
                          default=label_columns, help='要预测的列')
        self.add_argument('--label_in_feature_index', type=int, nargs='+',
                          default=label_in_feature_index,
                          help='标签在特征列的索引')
        # NOTE(review): input_size / output_size (and label_in_feature_index)
        # are derived from the defaults above only; they are not recomputed
        # if the column lists are overridden via YAML or the command line.
        self.add_argument('--input_size', type=int,
                          default=len(feature_columns), help='输入维度')
        self.add_argument('--output_size', type=int,
                          default=len(label_columns), help='输出维度')
        # train_data_path is expected to come from the YAML file.
        self.add_argument("--train_data_path", type=str, default=None,
                          help='数据集地址')
        # model_name and model_save_path are assembled in _init_dir from
        # values supplied by the YAML file.
        self.add_argument('--model_name', type=str, default=None,
                          help='模型名称')
        self.add_argument('--model_save_path', type=str, default=None,
                          help='模型保存地址')

    def _init_dir(self, opt):
        """Fill in derived paths on *opt* and create the output directories.

        Reads ``continue_flag``, ``used_frame``, ``model_postfix``,
        ``figure_save_path``, ``do_train``, ``do_log_save_to_file``,
        ``do_train_visualized`` and ``log_save_path`` from *opt*. None of
        these are registered in ``__init__``, so presumably the YAML file
        supplies them.

        Side effects: sets ``opt.model_name`` and ``opt.model_save_path``,
        and creates the checkpoint and figure directories. When training
        with file logging or visualisation enabled, it also creates a fresh
        timestamped log directory and stores it in ``opt.log_save_path``.
        """
        opt.model_name = ("model_" + opt.continue_flag + opt.used_frame
                          + opt.model_postfix[opt.used_frame])
        opt.model_save_path = './checkpoint/' + opt.model_name + "/"
        # makedirs creates missing parents; exist_ok avoids the
        # check-then-create race of os.path.exists() followed by a mkdir.
        os.makedirs(opt.model_save_path, exist_ok=True)
        os.makedirs(opt.figure_save_path, exist_ok=True)

        if opt.do_train and (opt.do_log_save_to_file or opt.do_train_visualized):
            cur_time = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
            log_save_path = (opt.log_save_path + cur_time + '_'
                             + opt.used_frame + "/")
            # No exist_ok here: each run is meant to get its own directory.
            os.makedirs(log_save_path)
            # Record the run-specific directory; previously it was created
            # but never stored, so opt still pointed at the base directory.
            opt.log_save_path = log_save_path

    def _parse_args_and_yaml(self):
        """Parse ``sys.argv``, letting the YAML file override the defaults.

        Returns:
            ``(opt, opt_text)``: the parsed ``argparse.Namespace`` and a YAML
            dump of it. The dump is a text copy of the args, intended to be
            saved to an output directory later.

        Raises:
            OSError: if the config file cannot be opened.
            TypeError: if the config file's top level is not a mapping.
        """
        given_configs, remaining = self.parse_known_args()
        if given_configs.config_yaml:
            with open(given_configs.config_yaml, 'r', encoding='utf-8') as f:
                # An empty YAML file loads as None: treat it as "no overrides".
                cfg = yaml.safe_load(f) or {}
            if not isinstance(cfg, dict):
                raise TypeError(
                    f"config file {given_configs.config_yaml!r} must contain "
                    f"a mapping at the top level, got {type(cfg).__name__}")
            self.set_defaults(**cfg)

        # The remaining arguments are parsed against defaults that the YAML
        # file (if any) has already overridden.
        opt = self.parse_args(remaining)
        self._init_dir(opt)
        opt_text = yaml.safe_dump(vars(opt), default_flow_style=False)
        return opt, opt_text

    def parse_args_and_yaml(self):
        """Return only the parsed namespace from ``_parse_args_and_yaml``."""
        return self._parse_args_and_yaml()[0]


if __name__ == "__main__":
    # Nothing to run: this module only defines the parser.
    pass