# -*- coding: UTF-8 -*-
"""Entry point: train and/or evaluate the model according to the parsed config.

Which phases run is controlled by ``opt.do_train`` / ``opt.do_predict`` on the
options object returned by ``myargparse(...).parse_args_and_yaml()``.
"""
import os
import sys

import numpy as np

# Seed NumPy before the project/model modules are imported so that any
# randomness they draw (including at import time) is reproducible.
np.random.seed(42)

from data_process import data_process
from data_features import data_features
from logger import load_logger
from config import myargparse
from data_analyse import data_analyse

frame = "keras"
if frame == "keras":
    # TensorFlow reads TF_CPP_MIN_LOG_LEVEL when its native library is loaded,
    # so it has to be set *before* importing the model module (which presumably
    # pulls in TensorFlow/Keras); setting it afterwards has no effect on the
    # log output produced during that import.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'
    from model.model_keras_fenqu import train, predict
else:
    raise Exception("Wrong frame selection")


def _split_targets(Y, n_targets=2):
    """Split per-sample target matrices into one stacked array per target column.

    Args:
        Y: iterable of 2-D arrays (indexed as ``y[:, k]``).
        n_targets: number of leading target columns to extract.

    Returns:
        A list of length ``n_targets`` whose k-th entry is
        ``np.array([y[:, k] for y in Y])``.
    """
    return [np.array([y[:, k] for y in Y]) for k in range(n_targets)]


def main():
    """Run training and/or prediction as configured.

    Returns:
        0 on success, 1 if any step raised (the traceback is logged via the
        project logger).
    """
    parse = myargparse(discription="training config", add_help=False)
    opt = parse.parse_args_and_yaml()
    logger = load_logger(opt)
    try:
        process = data_process(opt=opt)
        features = data_features(opt=opt)

        if opt.do_train:
            data_train = process.get_train_data()
            train_X, valid_X, train_Y, valid_Y = features.get_train_data([data_train])
            print("训练的数据集有{}个点".format(len(train_X[0])))
            # The model is trained on two targets (columns 0 and 1), passed to
            # train() as a list with one stacked array per target.
            train_Y = _split_targets(train_Y)
            valid_Y = _split_targets(valid_Y)
            train(opt, [train_X, train_Y, valid_X, valid_Y])

        if opt.do_predict:
            data_test = process.get_test_data()
            test_X, _test_Y, df_Y = features.get_test_data([data_test])
            print("测试集有{}个点".format(len(test_X)))
            # The predictions here are still normalized (not inverse-transformed).
            result = predict(opt, test_X)
            analyse = data_analyse(opt, logger)
            analyse.predict_acc(result, df_Y, predict_all=False)
    except Exception:
        # Top-level boundary: log the full traceback, then report failure
        # through the return value instead of silently exiting with status 0.
        logger.error("Run Error", exc_info=True)
        return 1
    return 0


if __name__ == "__main__":
    # Command-line/YAML options (e.g. do_train, do_predict) are parsed inside
    # main() via myargparse(...).parse_args_and_yaml().
    sys.exit(main())