model_prediction_lstm.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. from flask import Flask,request
  2. import time
  3. import logging
  4. import traceback
  5. import numpy as np
  6. from itertools import chain
  7. from common.database_dml import get_data_from_mongo,insert_data_into_mongo,get_h5_model_from_mongo,get_scaler_model_from_mongo
  8. from common.processing_data_common import str_to_list
  9. from common.alert import send_message
  10. from datetime import date, timedelta
  11. import pandas as pd
  12. app = Flask('model_prediction_lstm——service')
  13. # 创建时间序列数据
  14. def create_sequences(data_features,data_target,time_steps):
  15. X, y = [], []
  16. if len(data_features)<time_steps:
  17. print("数据长度不能比时间步长小!")
  18. return np.array(X), np.array(y)
  19. else:
  20. for i in range(len(data_features) - time_steps+1):
  21. X.append(data_features[i:(i + time_steps)])
  22. if len(data_target)>0:
  23. y.append(data_target[i + time_steps -1])
  24. return np.array(X), np.array(y)
  25. def forecast_data_distribution(pre_data,args):
  26. features, time_steps, col_time, model_name = str_to_list(args['features']), int(args['time_steps']), \
  27. args['col_time'], args['model_name'],
  28. feature_scaler, target_scaler = get_scaler_model_from_mongo(args)
  29. tomorrow = (date.today() + timedelta(days=1)).strftime('%Y-%m-%d')
  30. field_mapping = {'clearsky_ghi': 'clearskyGhi', 'dni_calcd': 'dniCalcd','surface_pressure':'surfacePressure',}
  31. # 根据字段映射重命名列
  32. pre_data = pre_data.rename(columns=field_mapping)
  33. diff = set(features) - set(pre_data.columns)
  34. if len(pre_data)==0:
  35. send_message('lstm预测组件', args['farmId'], '请注意:获取NWP数据为空,预测文件无法生成!')
  36. result = pd.DataFrame({col_time:[],'farm_id':[],'power_forecast':[]})
  37. elif len(diff)>0:
  38. send_message('lstm预测组件', args['farmId'], f'NWP特征列缺失!features:{diff}')
  39. result = pre_data[['date_time', 'farm_id', 'power_forecast']]
  40. elif len(pre_data[pre_data[col_time].str.contains(tomorrow)])<96:
  41. send_message('lstm预测组件', args['farmId'], "日前数据记录缺失,不足96条,用DQ代替并补值!")
  42. start_time = pre_data[col_time].min()
  43. end_time = pre_data[col_time].max()
  44. date_range = pd.date_range(start=start_time, end=end_time, freq='15T').strftime('%Y-%m-%d %H:%M:%S').tolist()
  45. df_date = pd.DataFrame({col_time:date_range})
  46. result = pd.merge(df_date,pre_data,how='left',on=col_time).sort_values(by=col_time).fillna(method='ffill').fillna(method='bfill')
  47. result = result[['date_time', 'farm_id', 'power_forecast']]
  48. else:
  49. df = pre_data.sort_values(by=col_time).fillna(method='ffill').fillna(method='bfill')
  50. scaled_features = feature_scaler.transform(df[features])
  51. X_predict, _ = create_sequences(scaled_features, [], time_steps)
  52. model = get_h5_model_from_mongo(args)
  53. y_predict = list(chain.from_iterable(target_scaler.inverse_transform([model.predict(X_predict).flatten()])))
  54. result = df[-len(y_predict):]
  55. result['power_forecast'] = y_predict
  56. result.loc[result['power_forecast'] < 0, 'power_forecast'] = 0
  57. return result[['date_time','farm_id','power_forecast']]
  58. def model_prediction(df,args):
  59. if 'is_limit' in df.columns:
  60. df = df[df['is_limit'] == False]
  61. features, time_steps, col_time, model_name,col_reserve,howlongago = str_to_list(args['features']), int(args['time_steps']),args['col_time'],args['model_name'],str_to_list(args['col_reserve']),int(args['howlongago'])
  62. feature_scaler,target_scaler = get_scaler_model_from_mongo(args)
  63. df = df.sort_values(by=col_time).fillna(method='ffill').fillna(method='bfill')
  64. scaled_features = feature_scaler.transform(df[features])
  65. X_predict, _ = create_sequences(scaled_features, [], time_steps)
  66. # 加载模型时传入自定义损失函数
  67. # model = load_model(f'{farmId}_model.h5', custom_objects={'rmse': rmse})
  68. model = get_h5_model_from_mongo(args)
  69. y_predict = list(chain.from_iterable(target_scaler.inverse_transform([model.predict(X_predict).flatten()])))
  70. result['howlongago'] = howlongago
  71. result = df[-len(y_predict):]
  72. result['predict'] = y_predict
  73. result.loc[result['predict'] < 0, 'predict'] = 0
  74. result['model'] = model_name
  75. features_reserve = col_reserve + ['model', 'predict', 'howlongago']
  76. return result[set(features_reserve)]
  77. @app.route('/model_prediction_lstm', methods=['POST'])
  78. def model_prediction_lstm():
  79. # 获取程序开始时间
  80. start_time = time.time()
  81. result = {}
  82. success = 0
  83. print("Program starts execution!")
  84. try:
  85. args = request.values.to_dict()
  86. print('args',args)
  87. logger.info(args)
  88. forecast_file = int(args['forecast_file'])
  89. power_df = get_data_from_mongo(args)
  90. if forecast_file == 1:
  91. predict_data = forecast_data_distribution(power_df,args)
  92. else:
  93. predict_data = model_prediction(power_df,args)
  94. insert_data_into_mongo(predict_data,args)
  95. success = 1
  96. except Exception as e:
  97. my_exception = traceback.format_exc()
  98. my_exception.replace("\n","\t")
  99. result['msg'] = my_exception
  100. end_time = time.time()
  101. result['success'] = success
  102. result['args'] = args
  103. result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
  104. result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
  105. print("Program execution ends!")
  106. return result
  107. if __name__=="__main__":
  108. print("Program starts execution!")
  109. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  110. logger = logging.getLogger("model_prediction_lstm log")
  111. from waitress import serve
  112. serve(app, host="0.0.0.0", port=10097)
  113. print("server start!")