config.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # time: 2023/3/17 14:46
  4. # file: config.py
  5. # author: David
  6. # company: shenyang JY
  7. import yaml
  8. import argparse
  9. class myargparse(argparse.ArgumentParser):
  10. def __init__(self, discription, add_help):
  11. super(myargparse, self).__init__(description=discription, add_help=add_help)
  12. self.add_argument('-c', '--config_yaml',default='config_xiushui.yml', type=str, metavar='FILE')
  13. self.add_argument('--norm_yaml', default='./data/xiushui/xiushui15/norm.yaml', type=str, metavar='FILE')
  14. self.add_argument('--input_size', type=int, default=0, help='输入维度')
  15. self.add_argument('--input_size_lstm', type=int, default=0, help='输入维度')
  16. self.add_argument('--input_size_cnn', type=int, default=0, help='输入维度')
  17. self.add_argument('--output_size', type=int, default=16, help='输出维度') # 16个点
  18. # model_name 和 model_save_path 这两个参数根据yml中的参数拼接而成
  19. self.add_argument('--model_prefix', type=str, default=None, help='模型名称')
  20. self.add_argument('--save_name', type=str, default=None, help='保存名称')
  21. self.add_argument('--model_save_path', type=str, default=None, help='模型保存地址')
  22. self.add_argument('--columns_lstm', type=list, default=None, help='lstm列名')
  23. self.add_argument('--columns_cnn', type=list, default=None, help='cnn列名')
  24. def _init_dir(self, opt):
  25. import os, time
  26. # 在这里给opt赋值
  27. opt.model_prefix = "model_" + opt.continue_flag
  28. opt.model_save_path = './checkpoint/' + opt.model_prefix + "/"
  29. opt.save_name = "model_" + opt.save_frame + opt.model_postfix['keras']
  30. if not os.path.exists(opt.model_save_path):
  31. os.makedirs(opt.model_save_path) # makedirs 递归创建目录
  32. if not os.path.exists(opt.figure_save_path):
  33. os.mkdir(opt.figure_save_path)
  34. if opt.do_train and (opt.do_log_save_to_file or opt.do_train_visualized):
  35. cur_time = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
  36. log_save_path = opt.log_save_path + cur_time + '_' + opt.used_frame + "/"
  37. os.makedirs(log_save_path)
  38. # YAML should override the argparser's content
  39. def _parse_args_and_yaml(self):
  40. given_configs, remaining = self.parse_known_args()
  41. if given_configs.config_yaml:
  42. with open(given_configs.config_yaml, 'r', encoding='utf-8') as f:
  43. cfg = yaml.safe_load(f)
  44. self.set_defaults(**cfg)
  45. if given_configs.norm_yaml:
  46. with open(given_configs.norm_yaml, 'r', encoding='utf-8') as f:
  47. cfg = yaml.safe_load(f)
  48. print("归一化参数:", cfg)
  49. self.set_defaults(**cfg)
  50. # The main arg parser parses the rest of the args, the usual
  51. # defaults will have been overridden if config file specified.
  52. opt = self.parse_args(remaining)
  53. self._init_dir(opt)
  54. # Cache the args as a text string to save them in the output dir later
  55. opt_text = yaml.safe_dump(opt.__dict__, default_flow_style=False)
  56. return opt, opt_text
  57. def parse_args_and_yaml(self):
  58. return self._parse_args_and_yaml()[0]
  59. if __name__ == "__main__":
  60. # opt = _parse_args_and_yaml()
  61. pass