1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162 |
- # -*- coding: UTF-8 -*-
- import numpy as np
- np.random.seed(42)
- import os
- from data_process import data_process
- from data_features import data_features
- from logger import load_logger
- from config import myargparse
- from data_analyse import data_analyse
# Deep-learning backend used by this script; "keras" is the only value this
# file knows how to load.
frame = "keras"
if frame == "keras":
    # The log-level variable has to be in the environment *before* the backend
    # is imported, otherwise native messages emitted during the import are
    # still printed. NOTE(review): assumes model.model_keras_base imports
    # TensorFlow/Keras at module level -- confirm.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'
    from model.model_keras_base import train, predict
else:
    raise Exception("Wrong frame seletion")
def main():
    """Run the training and/or prediction stages selected by the config.

    The options object comes from ``myargparse(...).parse_args_and_yaml()``.
    When ``opt.do_train`` is truthy the training data is loaded, turned into
    features and handed to ``train``; when ``opt.do_predict`` is truthy the
    test data is loaded, ``predict`` is called and the result is passed to
    ``data_analyse.predict_acc``.

    Any ``Exception`` raised by either stage is logged through the project
    logger (with traceback) and is not re-raised.
    """
    cli = myargparse(discription="training config", add_help=False)
    opt = cli.parse_args_and_yaml()
    logger = load_logger(opt)

    try:
        loader = data_process(opt=opt)
        feature_builder = data_features(opt=opt)

        if opt.do_train:
            raw_train = loader.get_train_data()
            splits = feature_builder.get_train_data([raw_train])
            train_X, valid_X, train_Y, valid_Y = splits
            print("训练的数据集有{}个点".format(len(train_X)))
            # Keep only column index 2 of every target sample, for the
            # training targets first and the validation targets second.
            train_Y, valid_Y = (
                np.array([sample[:, 2] for sample in targets])
                for targets in (train_Y, valid_Y)
            )
            train(opt, [train_X, train_Y, valid_X, valid_Y])

        if opt.do_predict:
            raw_test = loader.get_test_data()
            # Disabled experiment kept for reference:
            # dfs = [group for name, group in data_test.groupby('label')]
            test_X, _test_Y, df_Y = feature_builder.get_test_data([raw_test])
            print("测试集有{}个点".format(len(test_X)))
            analyse = data_analyse(opt, logger)
            predictions = predict(opt, test_X)
            analyse.predict_acc(predictions, df_Y, predict_all=True)
    except Exception:
        # Top-level boundary: record the failure with its traceback.
        logger.error("Run Error", exc_info=True)
if __name__ == "__main__":
    # Script entry point: runs main() only when the file is executed directly.
    import argparse
    # argparse makes it easy to pass parameters on the command line; more
    # options can be added as needed. The example below is disabled.
    # parser = argparse.ArgumentParser()
    # parser.add_argument("-t", "--do_train", default=False, type=bool, help="whether to train")
    # parser.add_argument("-p", "--do_predict", default=True, type=bool, help="whether to train")
    # parser.add_argument("-b", "--batch_size", default=64, type=int, help="batch size")
    # parser.add_argument("-e", "--epoch", default=20, type=int, help="epochs num")
    # args = parser.parse_args()
    # con = Config()
    # for key in dir(args):  # dir(args) returns every attribute of args
    #     if not key.startswith("_"):  # skip args' built-in attributes such as __name__
    #         setattr(con, key, getattr(args, key))  # copy the attribute value onto Config
    main()
|