config.py 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # time: 2023/3/17 14:46
  4. # file: config.py
  5. # author: David
  6. # company: shenyang JY
  7. import yaml
  8. import argparse
  class myargparse(argparse.ArgumentParser):
      """Argument parser whose defaults can be overridden by a YAML config file.

      Precedence, lowest to highest: the defaults declared in ``__init__``, the
      values read from the YAML file named by ``-c/--config_yaml``, and finally
      options given explicitly on the command line.
      """

      def __init__(self, discription, add_help):
          """Declare the command-line options.

          :param discription: description shown in ``--help`` (the misspelled
              name is kept so existing keyword callers keep working).
          :param add_help: whether to add the ``-h/--help`` option.
          """
          super().__init__(description=discription, add_help=add_help)
          self.add_argument(
              '-c',
              '--config_yaml',
              default='config.yml',
              type=str,
              metavar='FILE',
              help='YAML config file specifying default arguments')
          label_columns = ['C_REAL_VALUE']
          # Options settable on the command line that the yml file does not define.
          # List-valued options take a comma-separated value, e.g. --feature_columns 1,2,3
          self.add_argument('--feature_columns', type=self._parse_list, default=None, help='要作为特征的列')
          self.add_argument('--label_columns', type=self._parse_list, default=label_columns, help='要预测的列')
          self.add_argument('--label_in_feature_index', type=self._parse_list, default=None, help='标签在特征列的索引')
          self.add_argument('--input_size', type=int, default=0, help='输入维度')
          self.add_argument('--output_size', type=int, default=len(label_columns), help='输出维度')
          # train_data_path is also present in the yml file.
          self.add_argument("--train_data_path", type=str, default=None, help='数据集地址')
          # model_name and model_save_path are derived from yml values in _init_dir.
          self.add_argument('--model_name', type=str, default=None, help='模型名称')
          self.add_argument('--model_save_path', type=str, default=None, help='模型保存地址')

      @staticmethod
      def _parse_list(text):
          """Convert a comma-separated option value into a list.

          ``type=list`` would split ``"1,2"`` into ``['1', ',', '2']``. Instead the
          value is split on commas, blank items are dropped, and items that parse
          as integers are converted to ``int`` (e.g. ``"1, 2,x"`` -> ``[1, 2, 'x']``).
          """
          items = []
          for token in text.split(','):
              token = token.strip()
              if not token:
                  continue
              try:
                  items.append(int(token))
              except ValueError:
                  items.append(token)
          return items

      def _init_dir(self, opt):
          """Derive the model paths on ``opt`` and create the output directories.

          Sets ``opt.model_name`` and ``opt.model_save_path`` from ``continue_flag``,
          ``used_frame`` and ``model_postfix`` (which must be supplied, e.g. by the
          yml file). It creates the checkpoint and figure directories. When training
          with file logging or training visualisation enabled, it also creates a
          timestamped log directory and stores it in ``opt.log_save_path``.
          """
          import os
          import time
          opt.model_name = "model_" + opt.continue_flag + opt.used_frame + opt.model_postfix[opt.used_frame]
          opt.model_save_path = './checkpoint/' + opt.model_name + "/"
          # exist_ok avoids the exists()/create race and also builds missing parents
          # (plain os.mkdir failed for nested figure paths).
          os.makedirs(opt.model_save_path, exist_ok=True)
          os.makedirs(opt.figure_save_path, exist_ok=True)
          if opt.do_train and (opt.do_log_save_to_file or opt.do_train_visualized):
              cur_time = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
              log_save_path = opt.log_save_path + cur_time + '_' + opt.used_frame + "/"
              os.makedirs(log_save_path, exist_ok=True)
              # Keep the run-specific directory; previously it was created and then lost.
              opt.log_save_path = log_save_path

      def _parse_opt(self, args=None):
          """Parse ``args`` (default: ``sys.argv[1:]``) with YAML-overridden defaults.

          The YAML file, if any, only replaces the parser defaults. The full
          argument list is then parsed again, so options given explicitly on the
          command line still win over the YAML values.
          """
          given_configs, _ = self.parse_known_args(args)
          if given_configs.config_yaml:
              with open(given_configs.config_yaml, 'r', encoding='utf-8') as f:
                  # safe_load (never yaml.load): the file is user-supplied.
                  cfg = yaml.safe_load(f) or {}  # an empty file yields None
              self.set_defaults(**cfg)
          # Re-parse everything, not just the unknown leftovers. Parsing only the
          # remainder silently discarded options (e.g. --input_size, -c) that the
          # first pass had already recognised.
          opt = self.parse_args(args)
          self._init_dir(opt)
          return opt

      def _parse_args_and_yaml(self, args=None):
          """Return ``(opt, opt_text)``: the parsed namespace and its YAML dump.

          ``opt_text`` caches the final settings as text so they can later be
          saved in the output directory.
          """
          opt = self._parse_opt(args)
          opt_text = yaml.safe_dump(opt.__dict__, default_flow_style=False)
          return opt, opt_text

      def parse_args_and_yaml(self, args=None):
          """Parse the command line (or ``args``) plus YAML config; return the namespace."""
          return self._parse_opt(args)
  if __name__ == "__main__":
      # Intentionally a no-op when run as a script. To obtain the options, call
      # e.g. myargparse('Training Config', add_help=True).parse_args_and_yaml()
      pass