nn_bp.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # time: 2024/5/6 13:25
  4. # file: time_series.py
  5. # author: David
  6. # company: shenyang JY
  7. import json, copy
  8. import numpy as np
  9. from flask import Flask, request
  10. import time
  11. import traceback
  12. import logging, argparse
  13. from sklearn.preprocessing import MinMaxScaler
  14. from io import BytesIO
  15. import joblib
  16. from tensorflow.keras.layers import Input, Dense, LSTM, concatenate, Conv1D, Conv2D, MaxPooling1D, Reshape, Flatten
  17. from tensorflow.keras.models import Model, load_model
  18. from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, ReduceLROnPlateau
  19. from tensorflow.keras import optimizers, regularizers
  20. import tensorflow.keras.backend as K
  21. import tensorflow as tf
  22. from common.data_cleaning import cleaning
  23. from common.database_dml import *
  24. from common.processing_data_common import missing_features, str_to_list
  25. from data_processing.data_operation.data_handler import DataHandler
  26. from threading import Lock
  27. import time, yaml
  28. import random
  29. import matplotlib.pyplot as plt
  30. model_lock = Lock()
  31. from common.logs import Log
  32. logger = logging.getLogger()
  33. # logger = Log('models-processing').logger
  34. np.random.seed(42) # NumPy随机种子
  35. tf.random.set_random_seed(42) # TensorFlow随机种子
  36. app = Flask('nn_bp——service')
  37. with app.app_context():
  38. with open('../model_koi/bp.yaml', 'r', encoding='utf-8') as f:
  39. arguments = yaml.safe_load(f)
  40. dh = DataHandler(logger, arguments)
  41. def train_data_handler(data, opt):
  42. col_time, features, target = opt.col_time, opt.features, opt.target
  43. if 'is_limit' in data.columns:
  44. data = data[data['is_limit'] == False]
  45. # 清洗特征平均缺失率大于20%的天
  46. data = missing_features(data, features, col_time)
  47. train_data = data.sort_values(by=col_time).fillna(method='ffill').fillna(method='bfill')
  48. train_data = train_data.sort_values(by=col_time)
  49. # 对清洗完限电的数据进行特征预处理:1.空值异常值清洗 2.缺值补值
  50. train_data_cleaned = cleaning(train_data, 'nn_bp:features', logger, features)
  51. train_data = dh.fill_train_data(train_data_cleaned)
  52. # 创建特征和目标的标准化器
  53. train_scaler = MinMaxScaler(feature_range=(0, 1))
  54. # 标准化特征和目标
  55. scaled_train_data = train_scaler.fit_transform(train_data[features+[target]])
  56. # 保存两个scaler
  57. scaled_train_bytes = BytesIO()
  58. joblib.dump(scaled_train_data, scaled_train_bytes)
  59. scaled_train_bytes.seek(0) # Reset pointer to the beginning of the byte stream
  60. x_train, x_valid, y_train, y_valid = dh.get_train_data(scaled_train_data)
  61. return x_train, x_valid, y_train, y_valid, scaled_train_bytes
  62. def pre_data_handler(data, args):
  63. if 'is_limit' in data.columns:
  64. data = data[data['is_limit'] == False]
  65. features, time_steps, col_time, model_name,col_reserve = str_to_list(args['features']), int(args['time_steps']),args['col_time'],args['model_name'],str_to_list(args['col_reserve'])
  66. feature_scaler,target_scaler = get_scaler_model_from_mongo(args)
  67. pre_data = data.sort_values(by=col_time)
  68. scaled_features = feature_scaler.transform(pre_data[features])
  69. return scaled_features
  70. class BPHandler(object):
  71. def __init__(self, logger):
  72. self.logger = logger
  73. self.model = None
  74. def get_model(self, args):
  75. """
  76. 单例模式+线程锁,防止在异步加载时引发线程安全
  77. """
  78. try:
  79. with model_lock:
  80. # NPHandler.model = NPHandler.get_keras_model(opt)
  81. self.model = get_h5_model_from_mongo(args)
  82. except Exception as e:
  83. self.logger.info("加载模型权重失败:{}".format(e.args))
  84. @staticmethod
  85. def get_keras_model(opt):
  86. # db_loss = NorthEastLoss(opt)
  87. # south_loss = SouthLoss(opt)
  88. l1_reg = regularizers.l1(opt.Model['lambda_value_1'])
  89. l2_reg = regularizers.l2(opt.Model['lambda_value_2'])
  90. nwp_input = Input(shape=(opt.Model['time_step'], opt.Model['input_size_nwp']), name='nwp')
  91. env_input = Input(shape=(opt.Model['his_points'], opt.Model['input_size_env']), name='env')
  92. con1 = Conv1D(filters=64, kernel_size=1, strides=1, padding='valid', activation='relu',
  93. kernel_regularizer=l2_reg)(nwp_input)
  94. d1 = Dense(32, activation='relu', name='d1', kernel_regularizer=l1_reg)(con1)
  95. nwp = Dense(8, activation='relu', name='d2', kernel_regularizer=l1_reg)(d1)
  96. output = Dense(opt.Model['output_size'], name='d5')(nwp)
  97. model = Model([env_input, nwp_input], output)
  98. adam = optimizers.Adam(learning_rate=opt.Model['learning_rate'], beta_1=0.9, beta_2=0.999, epsilon=1e-7,
  99. amsgrad=True)
  100. reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.01, patience=5, verbose=1)
  101. model.compile(loss='rmse', optimizer=adam)
  102. return model
  103. def train_init(self, opt):
  104. try:
  105. if opt.Model['add_train']:
  106. # 进行加强训练,支持修模
  107. base_train_model = get_h5_model_from_mongo(vars(opt))
  108. base_train_model.summary()
  109. self.logger.info("已加载加强训练基础模型")
  110. else:
  111. base_train_model = self.get_keras_model(opt)
  112. return base_train_model
  113. except Exception as e:
  114. self.logger.info("加强训练加载模型权重失败:{}".format(e.args))
  115. def training(self, opt, train_and_valid_data):
  116. model = self.train_init(opt)
  117. tf.reset_default_graph() # 清除默认图
  118. train_x, train_y, valid_x, valid_y = train_and_valid_data
  119. print("----------", np.array(train_x[0]).shape)
  120. print("++++++++++", np.array(train_x[1]).shape)
  121. check_point = ModelCheckpoint(filepath='./var/' + 'fmi.h5', monitor='val_loss',
  122. save_best_only=True, mode='auto')
  123. early_stop = EarlyStopping(monitor='val_loss', patience=opt.Model['patience'], mode='auto')
  124. history = model.fit(train_x, train_y, batch_size=opt.Model['batch_size'], epochs=opt.Model['epoch'], verbose=2,
  125. validation_data=(valid_x, valid_y), callbacks=[check_point, early_stop], shuffle=False)
  126. loss = np.round(history.history['loss'], decimals=5)
  127. val_loss = np.round(history.history['val_loss'], decimals=5)
  128. self.logger.info("-----模型训练经过{}轮迭代-----".format(len(loss)))
  129. self.logger.info("训练集损失函数为:{}".format(loss))
  130. self.logger.info("验证集损失函数为:{}".format(val_loss))
  131. return model
  132. def predict(self, test_X, batch_size=1):
  133. result = self.model.predict(test_X, batch_size=batch_size)
  134. self.logger.info("执行预测方法")
  135. return result
  136. @app.route('/model_training_bp', methods=['POST'])
  137. def model_training_bp():
  138. # 获取程序开始时间
  139. start_time = time.time()
  140. result = {}
  141. success = 0
  142. bp = BPHandler(logger)
  143. print("Program starts execution!")
  144. try:
  145. args_dict = request.values.to_dict()
  146. args = arguments.deepcopy()
  147. opt = argparse.Namespace(**args)
  148. logger.info(args_dict)
  149. train_data = get_data_from_mongo(args_dict)
  150. train_x, valid_x, train_y, valid_y, scaled_train_bytes = train_data_handler(train_data, opt)
  151. bp_model = bp.training(opt, [train_x, valid_x, train_y, valid_y])
  152. args_dict['params'] = json.dumps(args)
  153. args_dict['gen_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
  154. insert_trained_model_into_mongo(bp_model, args_dict)
  155. insert_scaler_model_into_mongo(scaled_train_bytes, args_dict)
  156. success = 1
  157. except Exception as e:
  158. my_exception = traceback.format_exc()
  159. my_exception.replace("\n", "\t")
  160. result['msg'] = my_exception
  161. end_time = time.time()
  162. result['success'] = success
  163. result['args'] = args
  164. result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
  165. result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
  166. print("Program execution ends!")
  167. return result
  168. @app.route('/model_prediction_bp', methods=['POST'])
  169. def model_prediction_bp():
  170. # 获取程序开始时间
  171. start_time = time.time()
  172. result = {}
  173. success = 0
  174. bp = BPHandler(logger)
  175. print("Program starts execution!")
  176. try:
  177. params_dict = request.values.to_dict()
  178. args = arguments.deepcopy()
  179. args.update(params_dict)
  180. opt = argparse.Namespace(**args)
  181. print('args', args)
  182. logger.info(args)
  183. predict_data = get_data_from_mongo(args)
  184. scaled_features = pre_data_handler(predict_data, args)
  185. result = bp.predict(scaled_features, args)
  186. insert_data_into_mongo(result, args)
  187. success = 1
  188. except Exception as e:
  189. my_exception = traceback.format_exc()
  190. my_exception.replace("\n", "\t")
  191. result['msg'] = my_exception
  192. end_time = time.time()
  193. result['success'] = success
  194. result['args'] = args
  195. result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
  196. result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
  197. print("Program execution ends!")
  198. return result
  199. if __name__ == "__main__":
  200. print("Program starts execution!")
  201. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  202. logger = logging.getLogger("model_training_bp log")
  203. from waitress import serve
  204. # serve(app, host="0.0.0.0", port=10103, threads=4)
  205. print("server start!")
  206. bp = BPHandler(logger)
  207. args = copy.deepcopy(bp)
  208. opt = argparse.Namespace(**arguments)
  209. logger.info(args)
  210. args_dict = {"mongodb_database": 'david_test', 'scaler_table': 'j00083_scaler', 'model_name': 'bp1.0.test',
  211. 'model_table': 'j00083_model', 'mongodb_read_table': 'j00083'}
  212. train_data = get_data_from_mongo(args_dict)
  213. train_x, valid_x, train_y, valid_y, scaled_train_bytes = train_data_handler(train_data, opt)
  214. bp_model = bp.training(opt, [train_x, valid_x, train_y, valid_y])
  215. insert_trained_model_into_mongo(bp_model, args_dict)
  216. insert_scaler_model_into_mongo(scaled_train_bytes, args)