# model_prediction_lstm.py (5.9 KB)
import ast
import logging
import os
import time
import traceback
from io import BytesIO
from itertools import chain

import h5py
import joblib
import numpy as np
import pandas as pd
import tensorflow as tf
from flask import Flask, request
from pymongo import MongoClient
  13. app = Flask('model_prediction_lstm——service')
  14. def get_data_from_mongo(args):
  15. mongodb_connection,mongodb_database,mongodb_read_table,timeBegin,timeEnd = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['mongodb_read_table'],args['timeBegin'],args['timeEnd']
  16. client = MongoClient(mongodb_connection)
  17. # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
  18. db = client[mongodb_database]
  19. collection = db[mongodb_read_table] # 集合名称
  20. query = {"dateTime": {"$gte": timeBegin, "$lte": timeEnd}}
  21. cursor = collection.find(query)
  22. data = list(cursor)
  23. df = pd.DataFrame(data)
  24. # 4. 删除 _id 字段(可选)
  25. if '_id' in df.columns:
  26. df = df.drop(columns=['_id'])
  27. client.close()
  28. return df
  29. def insert_data_into_mongo(res_df,args):
  30. mongodb_connection,mongodb_database,mongodb_write_table = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['mongodb_write_table']
  31. client = MongoClient(mongodb_connection)
  32. db = client[mongodb_database]
  33. if mongodb_write_table in db.list_collection_names():
  34. db[mongodb_write_table].drop()
  35. print(f"Collection '{mongodb_write_table} already exist, deleted successfully!")
  36. collection = db[mongodb_write_table] # 集合名称
  37. # 将 DataFrame 转为字典格式
  38. data_dict = res_df.to_dict("records") # 每一行作为一个字典
  39. # 插入到 MongoDB
  40. collection.insert_many(data_dict)
  41. print("data inserted successfully!")
  42. def get_model_from_mongo(args):
  43. mongodb_connection,mongodb_database,model_table,model_name = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['model_table'],args['model_name']
  44. client = MongoClient(mongodb_connection)
  45. # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
  46. db = client[mongodb_database]
  47. collection = db[model_table] # 集合名称
  48. # 查询 MongoDB 获取模型数据
  49. model_doc = collection.find_one({"model_name": model_name})
  50. if model_doc:
  51. model_data = model_doc['model_data'] # 获取模型的二进制数据
  52. # 将二进制数据加载到 BytesIO 缓冲区
  53. model_buffer = BytesIO(model_data)
  54. # 从缓冲区加载模型
  55. # 使用 h5py 和 BytesIO 从内存中加载模型
  56. with h5py.File(model_buffer, 'r') as f:
  57. model = tf.keras.models.load_model(f)
  58. print(f"{model_name}模型成功从 MongoDB 加载!")
  59. client.close()
  60. return model
  61. else:
  62. print(f"未找到model_name为 {model_name} 的模型。")
  63. client.close()
  64. return None
  65. # 创建时间序列数据
  66. def create_sequences(data_features,data_target,time_steps):
  67. X, y = [], []
  68. if len(data_features)<time_steps:
  69. print("数据长度不能比时间步长小!")
  70. return np.array(X), np.array(y)
  71. else:
  72. for i in range(len(data_features) - time_steps+1):
  73. X.append(data_features[i:(i + time_steps)])
  74. if len(data_target)>0:
  75. y.append(data_target[i + time_steps -1])
  76. return np.array(X), np.array(y)
  77. def model_prediction(df,args):
  78. mongodb_connection, mongodb_database, scaler_table, features, time_steps = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
  79. args['mongodb_database'], args['scaler_table'],args['features'],args['time_steps'])
  80. client = MongoClient(mongodb_connection)
  81. # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
  82. db = client[mongodb_database]
  83. collection = db[scaler_table] # 集合名称
  84. # Retrieve the scalers from MongoDB
  85. scaler_doc = collection.find_one()
  86. # Deserialize the scalers
  87. feature_scaler_bytes = BytesIO(scaler_doc["feature_scaler"])
  88. feature_scaler = joblib.load(feature_scaler_bytes)
  89. target_scaler_bytes = BytesIO(scaler_doc["target_scaler"])
  90. target_scaler = joblib.load(target_scaler_bytes)
  91. scaled_features = feature_scaler.transform(df[features])
  92. X_predict, _ = create_sequences(scaled_features, [], time_steps)
  93. # 加载模型时传入自定义损失函数
  94. # model = load_model(f'{farmId}_model.h5', custom_objects={'rmse': rmse})
  95. model = get_model_from_mongo(args)
  96. y_predict = list(chain.from_iterable(target_scaler.inverse_transform([model.predict(X_predict).flatten()])))
  97. result = df[-len(y_predict):]
  98. result['predict'] = y_predict
  99. return result
  100. @app.route('/model_prediction_lstm', methods=['POST'])
  101. def model_prediction_lstm():
  102. # 获取程序开始时间
  103. start_time = time.time()
  104. result = {}
  105. success = 0
  106. print("Program starts execution!")
  107. try:
  108. args = request.values.to_dict()
  109. print('args',args)
  110. logger.info(args)
  111. power_df = get_data_from_mongo(args)
  112. model = model_prediction(power_df,args)
  113. insert_data_into_mongo(model,args)
  114. success = 1
  115. except Exception as e:
  116. my_exception = traceback.format_exc()
  117. my_exception.replace("\n","\t")
  118. result['msg'] = my_exception
  119. end_time = time.time()
  120. result['success'] = success
  121. result['args'] = args
  122. result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
  123. result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
  124. print("Program execution ends!")
  125. return result
  126. if __name__=="__main__":
  127. print("Program starts execution!")
  128. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  129. logger = logging.getLogger("model_prediction_lstm log")
  130. from waitress import serve
  131. serve(app, host="0.0.0.0", port=10097)
  132. print("server start!")