import logging
import time
import traceback
from io import BytesIO

import h5py
import joblib
import numpy as np
import pandas as pd
import tensorflow as tf
from flask import Flask, request
from pymongo import MongoClient

# Connection string shared by every MongoDB helper below
MONGODB_URI = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"

app = Flask('model_prediction_lstm_service')

# Configured at module level so the route handler can log even when the app
# is served by an external WSGI runner instead of the __main__ block below
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("model_prediction_lstm log")


def str_to_list(arg):
    """Split a comma-separated string into a list; '' yields []."""
    if arg == '':
        return []
    return arg.split(',')


def get_data_from_mongo(args):
    mongodb_database = args['mongodb_database']
    mongodb_read_table = args['mongodb_read_table']
    time_begin, time_end = args['timeBegin'], args['timeEnd']
    with MongoClient(MONGODB_URI) as client:
        # Select the database (MongoDB creates it automatically if it does not exist)
        db = client[mongodb_database]
        collection = db[mongodb_read_table]  # collection name
        query = {"dateTime": {"$gte": time_begin, "$lte": time_end}}
        df = pd.DataFrame(list(collection.find(query)))
    # Drop the _id field (optional)
    if '_id' in df.columns:
        df = df.drop(columns=['_id'])
    return df


def insert_data_into_mongo(res_df, args):
    mongodb_database = args['mongodb_database']
    mongodb_write_table = args['mongodb_write_table']
    with MongoClient(MONGODB_URI) as client:
        db = client[mongodb_database]
        # The output collection is overwritten on every run
        if mongodb_write_table in db.list_collection_names():
            db[mongodb_write_table].drop()
            print(f"Collection '{mongodb_write_table}' already existed and was dropped.")
        collection = db[mongodb_write_table]  # collection name
        # Convert the DataFrame to a list of dicts, one per row
        data_dict = res_df.to_dict("records")
        # Insert into MongoDB
        collection.insert_many(data_dict)
    print("Data inserted successfully!")


def get_model_from_mongo(args):
    mongodb_database = args['mongodb_database']
    model_table, model_name = args['model_table'], args['model_name']
    with MongoClient(MONGODB_URI) as client:
        # Select the database (MongoDB creates it automatically if it does not exist)
        db = client[mongodb_database]
        collection = db[model_table]  # collection name
        # Query MongoDB for the model document
        model_doc = collection.find_one({"model_name": model_name})
    if not model_doc:
        print(f"No model found with model_name {model_name}.")
        return None
    # Wrap the binary model data in an in-memory buffer and load it via h5py.
    # Passing an h5py.File to load_model is supported by tf.keras (Keras 2);
    # newer Keras releases may only accept file paths here.
    model_buffer = BytesIO(model_doc['model_data'])
    with h5py.File(model_buffer, 'r') as f:
        model = tf.keras.models.load_model(f)
    print(f"Model {model_name} loaded successfully from MongoDB!")
    return model


# Build sliding-window time-series samples: each X entry holds `time_steps`
# consecutive feature rows, and y is the target at the last step of the window.
def create_sequences(data_features, data_target, time_steps):
    X, y = [], []
    # Assumed layout: one window per valid start index, so
    # len(X) == len(data_features) - time_steps + 1 (0 if there are too few rows)
    for i in range(len(data_features) - time_steps + 1):
        X.append(data_features[i:i + time_steps])
        if len(data_target) > 0:
            y.append(data_target[i + time_steps - 1])
    return np.array(X), np.array(y)


def model_prediction(df, args):
    mongodb_database, scaler_table = args['mongodb_database'], args['scaler_table']
    features = str_to_list(args['features'])
    time_steps, col_time = int(args['time_steps']), args['col_time']
    with MongoClient(MONGODB_URI) as client:
        # Select the database (MongoDB creates it automatically if it does not exist)
        db = client[mongodb_database]
        collection = db[scaler_table]  # collection name
        # Retrieve the serialized scalers from MongoDB
        scaler_doc = collection.find_one()
    if scaler_doc is None:
        raise ValueError(f"No scaler document found in collection '{scaler_table}'.")
    # Deserialize the scalers
    feature_scaler = joblib.load(BytesIO(scaler_doc["feature_scaler"]))
    target_scaler = joblib.load(BytesIO(scaler_doc["target_scaler"]))

    # Sort by time first so forward/backward filling follows the time order
    df = df.sort_values(by=col_time).ffill().bfill()
    scaled_features = feature_scaler.transform(df[features])
    X_predict, _ = create_sequences(scaled_features, [], time_steps)
    if len(X_predict) == 0:
        raise ValueError(f"Need at least {time_steps} rows to build an input window, got {len(df)}.")

    # Pass the custom loss function when loading the model
    # model = load_model(f'{farmId}_model.h5', custom_objects={'rmse': rmse})
    model = get_model_from_mongo(args)
    if model is None:
        raise ValueError(f"Model '{args['model_name']}' not found in '{args['model_table']}'.")
    # Undo the target scaling; assumes target_scaler was fitted on one column,
    # so predictions are reshaped to (n_samples, 1)
    y_scaled = model.predict(X_predict).reshape(-1, 1)
    y_predict = target_scaler.inverse_transform(y_scaled).ravel().tolist()
    # Window i predicts row i + time_steps - 1, i.e. the last len(y_predict) rows
    result = df.iloc[-len(y_predict):].copy()
    result['predict'] = y_predict
    return result


@app.route('/model_prediction_lstm', methods=['POST'])
def model_prediction_lstm():
    # Record the start time
    start_time = time.time()
    result = {}
    success = 0
    args = {}
    print("Program starts execution!")
    try:
        args = request.values.to_dict()
        logger.info(args)
        power_df = get_data_from_mongo(args)
        result_df = model_prediction(power_df, args)
        insert_data_into_mongo(result_df, args)
        success = 1
    except Exception:
        logger.exception("Prediction request failed")
        # Keep the traceback on a single line in the JSON response
        result['msg'] = traceback.format_exc().replace("\n", "\t")
    end_time = time.time()
    result['success'] = success
    result['args'] = args
    result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
    result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
    print("Program execution ends!")
    return result


if __name__ == "__main__":
    from waitress import serve

    # serve() blocks until shutdown, so announce startup before calling it
    print("Server starting on 0.0.0.0:10097!")
    serve(app, host="0.0.0.0", port=10097)