# model_training_lstm.py
# Flask/waitress service: trains an LSTM on time-series data pulled from
# MongoDB and writes the fitted model and scalers back into MongoDB.
import time
import traceback
import logging
from io import BytesIO

import joblib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pymongo import MongoClient
from flask import Flask, request
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

app = Flask('model_training_lstm_service')
# Configure logging at import time so the request handler's logger always exists
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("model_training_lstm")

def draw_loss(history):
    # Plot training- and validation-set loss
    plt.figure(figsize=(20, 8))
    plt.plot(history.history['loss'], label='Training Loss')
    plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.title('Loss Curve')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    # plt.show() needs a GUI backend; under a headless deployment it shows nothing
    plt.show()

def get_data_from_mongo(args):
    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
    mongodb_database = args['mongodb_database']
    mongodb_read_table = args['mongodb_read_table']
    timeBegin, timeEnd = args['timeBegin'], args['timeEnd']
    client = MongoClient(mongodb_connection)
    # Select the database (MongoDB creates it automatically if it does not exist)
    db = client[mongodb_database]
    collection = db[mongodb_read_table]  # collection name
    query = {"dateTime": {"$gte": timeBegin, "$lte": timeEnd}}
    cursor = collection.find(query)
    data = list(cursor)
    df = pd.DataFrame(data)
    # Drop the _id field (optional)
    if '_id' in df.columns:
        df = df.drop(columns=['_id'])
    client.close()
    return df
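
# Example args consumed above (hypothetical values): the handler passes the
# POSTed form fields straight through, so these are plain strings.
#   {"mongodb_database": "power_db", "mongodb_read_table": "station_01",
#    "timeBegin": "2024-01-01 00:00:00", "timeEnd": "2024-01-31 23:45:00"}
# The $gte/$lte comparison uses the stored values as-is, so timeBegin/timeEnd
# must match the dateTime field's type and format in the collection.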

def insert_model_into_mongo(model, feature_scaler_bytes, target_scaler_bytes, args):
    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
    mongodb_database = args['mongodb_database']
    scaler_table, model_table, model_name = args['scaler_table'], args['model_table'], args['model_name']
    client = MongoClient(mongodb_connection)
    db = client[mongodb_database]
    collection = db[scaler_table]  # collection name
    # Save the scalers in MongoDB as binary data
    collection.insert_one({
        "feature_scaler": feature_scaler_bytes.read(),
        "target_scaler": target_scaler_bytes.read()
    })
    print("scalers inserted successfully!")
    model_table = db[model_table]
    # Create a BytesIO buffer
    model_buffer = BytesIO()
    # Serialize the model in HDF5 format into memory
    # (h5py >= 2.9 accepts file-like objects such as BytesIO)
    model.save(model_buffer, save_format='h5')
    # Rewind the buffer to the beginning
    model_buffer.seek(0)
    # Read out the model's binary payload
    model_data = model_buffer.read()
    # Store the model in MongoDB
    model_table.insert_one({
        "model_name": model_name,
        "model_data": model_data
    })
    print("model saved to MongoDB successfully!")
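
# A minimal sketch of the inverse operation (not part of the original service):
# read a stored model back out of MongoDB. The collection and field names match
# insert_model_into_mongo above; the `db` handle is assumed to be an open
# pymongo Database.
def load_model_from_mongo(db, model_table, model_name):
    import h5py  # h5py >= 2.9 can open in-memory file-like objects
    doc = db[model_table].find_one({"model_name": model_name})
    model_buffer = BytesIO(doc["model_data"])
    with h5py.File(model_buffer, 'r') as f:
        return tf.keras.models.load_model(f)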

def rmse(y_true, y_pred):
    return tf.math.sqrt(tf.reduce_mean(tf.square(y_true - y_pred)))
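
# rmse is not wired into build_model below (which compiles with
# 'mean_squared_error'); it could be attached as a metric, e.g.:
#   model.compile(optimizer='adam', loss='mean_squared_error', metrics=[rmse])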

# Build sliding-window time-series samples
def create_sequences(data_features, data_target, time_steps):
    X, y = [], []
    if len(data_features) < time_steps:
        print("Data length must be at least the number of time steps!")
        return np.array(X), np.array(y)
    else:
        for i in range(len(data_features) - time_steps + 1):
            X.append(data_features[i:(i + time_steps)])
            if len(data_target) > 0:
                y.append(data_target[i + time_steps - 1])
        return np.array(X), np.array(y)
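
# Shape example for create_sequences: 5 rows of 2 scaled features with
# time_steps=3 yield 3 overlapping windows:
#   X.shape == (3, 3, 2)   # (samples, time_steps, features)
#   y.shape == (3, 1)      # target aligned to the last step of each window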

def build_model(data, args):
    col_time, time_steps = args['col_time'], int(args['time_steps'])
    features, target = str_to_list(args['features']), args['target']
    train_data = data.ffill().bfill().sort_values(by=col_time)
    # Build scalers for the features and the target
    feature_scaler = MinMaxScaler(feature_range=(0, 1))
    target_scaler = MinMaxScaler(feature_range=(0, 1))
    # Scale features and target
    scaled_features = feature_scaler.fit_transform(train_data[features])
    scaled_target = target_scaler.fit_transform(train_data[[target]])
    # Serialize both scalers
    feature_scaler_bytes = BytesIO()
    joblib.dump(feature_scaler, feature_scaler_bytes)
    feature_scaler_bytes.seek(0)  # Reset pointer to the beginning of the byte stream
    target_scaler_bytes = BytesIO()
    joblib.dump(target_scaler, target_scaler_bytes)
    target_scaler_bytes.seek(0)
    X, y = create_sequences(scaled_features, scaled_target, time_steps)
    # Split into training and test sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=43)
    # Build the LSTM model
    model = Sequential()
    model.add(LSTM(units=50, return_sequences=False, input_shape=(time_steps, X_train.shape[2])))
    model.add(Dense(1))  # single output value
    # Compile the model
    model.compile(optimizer='adam', loss='mean_squared_error')
    # Define the EarlyStopping and ReduceLROnPlateau callbacks
    early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True, verbose=1)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, verbose=1)
    # Train the model
    history = model.fit(X_train, y_train,
                        epochs=100,
                        batch_size=32,
                        validation_data=(X_test, y_test),
                        verbose=2,
                        callbacks=[early_stopping, reduce_lr])
    draw_loss(history)
    return model, feature_scaler_bytes, target_scaler_bytes

def str_to_list(arg):
    if arg == '':
        return []
    else:
        return arg.split(',')
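
# e.g. str_to_list('temp,humidity') -> ['temp', 'humidity']; str_to_list('') -> []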

@app.route('/model_training_lstm', methods=['POST'])
def model_training_lstm():
    # Record the program start time
    start_time = time.time()
    result = {}
    success = 0
    args = {}  # initialized so the response body is valid even if parsing fails
    print("Program starts execution!")
    try:
        args = request.values.to_dict()
        print('args', args)
        logger.info(args)
        power_df = get_data_from_mongo(args)
        model, feature_scaler_bytes, target_scaler_bytes = build_model(power_df, args)
        insert_model_into_mongo(model, feature_scaler_bytes, target_scaler_bytes, args)
        success = 1
    except Exception:
        my_exception = traceback.format_exc()
        # str.replace returns a new string; the result must be assigned
        my_exception = my_exception.replace("\n", "\t")
        result['msg'] = my_exception
    end_time = time.time()
    result['success'] = success
    result['args'] = args
    result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
    result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
    print("Program execution ends!")
    return result
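
# Example request (hypothetical values; the field names are the ones the
# handler reads out of request.values):
#   curl -X POST http://localhost:10096/model_training_lstm \
#        -d mongodb_database=power_db -d mongodb_read_table=station_01 \
#        -d timeBegin="2024-01-01 00:00:00" -d timeEnd="2024-01-31 23:45:00" \
#        -d col_time=dateTime -d time_steps=24 -d features=temp,humidity \
#        -d target=power -d scaler_table=scalers -d model_table=models \
#        -d model_name=lstm_v1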

if __name__ == "__main__":
    print("Program starts execution!")
    from waitress import serve
    print("server start!")
    # serve() blocks, so anything placed after it would never run
    serve(app, host="0.0.0.0", port=10096)