run_case_分区.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. # -*- coding: UTF-8 -*-
  2. import numpy as np
  3. np.random.seed(42)
  4. import os
  5. from data_process import data_process
  6. from data_features import data_features
  7. from logger import load_logger
  8. from config import myargparse
  9. from data_analyse import data_analyse
  10. frame = "keras"
  11. if frame == "keras":
  12. from model.model_keras_fenqu import train, predict
  13. os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'
  14. else:
  15. raise Exception("Wrong frame seletion")
  16. def main():
  17. parse = myargparse(discription="training config", add_help=False)
  18. opt = parse.parse_args_and_yaml()
  19. logger = load_logger(opt)
  20. try:
  21. process = data_process(opt=opt)
  22. features = data_features(opt=opt)
  23. if opt.do_train:
  24. data_train = process.get_train_data()
  25. train_X, valid_X, train_Y, valid_Y = features.get_train_data([data_train])
  26. print("训练的数据集有{}个点".format(len(train_X[0])))
  27. # train_Y = [np.array([y[:, 0] for y in train_Y])]
  28. # valid_Y = [np.array([y[:, 0] for y in valid_Y])]
  29. # train(opt, [train_X, train_Y, valid_X, valid_Y])
  30. train_Y = [np.array([y[:, 0] for y in train_Y]), np.array([y[:, 1] for y in train_Y])]
  31. valid_Y = [np.array([y[:, 0] for y in valid_Y]), np.array([y[:, 1] for y in valid_Y])]
  32. train(opt, [train_X, train_Y, valid_X, valid_Y])
  33. if opt.do_predict:
  34. data_test = process.get_test_data()
  35. test_X, test_Y, df_Y = features.get_test_data([data_test])
  36. print("测试集有{}个点".format(len(test_X)))
  37. result = predict(opt, test_X) # 这里输出的是未还原的归一化预测数据
  38. analyse = data_analyse(opt, logger)
  39. # analyse.predict_acc(result, df_Y, predict_all=True)
  40. analyse.predict_acc(result, df_Y, predict_all=False)
  41. except Exception:
  42. logger.error("Run Error", exc_info=True)
  43. if __name__ == "__main__":
  44. import argparse
  45. # argparse方便于命令行下输入参数,可以根据需要增加更多
  46. # parser = argparse.ArgumentParser()
  47. # parser.add_argument("-t", "--do_train", default=False, type=bool, help="whether to train")
  48. # parser.add_argument("-p", "--do_predict", default=True, type=bool, help="whether to train")
  49. # parser.add_argument("-b", "--batch_size", default=64, type=int, help="batch size")
  50. # parser.add_argument("-e", "--epoch", default=20, type=int, help="epochs num")
  51. # args = parser.parse_args()
  52. # con = Config()
  53. # for key in dir(args): # dir(args) 函数获得args所有的属性
  54. # if not key.startswith("_"): # 去掉 args 自带属性,比如__name__等
  55. # setattr(con, key, getattr(args, key)) # 将属性值赋给Config
  56. main()