run_case.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. # -*- coding: UTF-8 -*-
  2. import numpy as np
  3. import os
  4. import sys
  5. import time
  6. from figure import Figure
  7. from data_process import data_process
  8. from data_features import data_features
  9. from logger import load_logger
  10. from config import myargparse
  11. from data_analyse import data_analyse
  12. frame = "keras"
  13. if frame == "keras":
  14. from model.model_keras import train, predict
  15. os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'
  16. else:
  17. raise Exception("Wrong frame seletion")
  18. def main():
  19. parse = myargparse(discription="training config", add_help=False)
  20. opt = parse.parse_args_and_yaml()
  21. logger = load_logger(opt)
  22. try:
  23. np.random.seed(opt.Model["random_seed"])
  24. process = data_process(opt=opt)
  25. dfs = process.get_processed_data()
  26. features = data_features(opt=opt, mean=process.mean, std=process.std)
  27. if opt.do_train:
  28. train_X, valid_X, train_Y, valid_Y = features.get_train_data([dfs[0][:'2021/8/1'], dfs[1][:'2022/3/1']])
  29. train(opt, [train_X, train_Y, valid_X, valid_Y])
  30. if opt.do_predict:
  31. test_X, test_Y, df_Y = features.get_test_data([dfs[0]['2021/8/1':'2021/9/6'], dfs[1]['2022/3/1':'2022/4/4']])
  32. result = predict(opt, test_X) # 这里输出的是未还原的归一化预测数据
  33. analyse = data_analyse(opt, logger, process)
  34. analyse.predict_acc(result, df_Y)
  35. except Exception:
  36. logger.error("Run Error", exc_info=True)
  37. if __name__ == "__main__":
  38. import argparse
  39. # argparse方便于命令行下输入参数,可以根据需要增加更多
  40. # parser = argparse.ArgumentParser()
  41. # parser.add_argument("-t", "--do_train", default=False, type=bool, help="whether to train")
  42. # parser.add_argument("-p", "--do_predict", default=True, type=bool, help="whether to train")
  43. # parser.add_argument("-b", "--batch_size", default=64, type=int, help="batch size")
  44. # parser.add_argument("-e", "--epoch", default=20, type=int, help="epochs num")
  45. # args = parser.parse_args()
  46. # con = Config()
  47. # for key in dir(args): # dir(args) 函数获得args所有的属性
  48. # if not key.startswith("_"): # 去掉 args 自带属性,比如__name__等
  49. # setattr(con, key, getattr(args, key)) # 将属性值赋给Config
  50. main()