# model_prediction_lightgbm.py - Flask service that serves predictions from a pickled LightGBM model.
  1. import pandas as pd
  2. from pymongo import MongoClient
  3. import pickle
  4. from flask import Flask,request
  5. import time
  6. import logging
  7. import traceback
  8. from common.database_dml import get_data_from_mongo,insert_data_into_mongo
  9. from common.alert import send_message
  10. from datetime import datetime, timedelta
  11. import pytz
  12. from pytz import timezone
  13. from common.processing_data_common import get_xxl_dq
# Flask application object; the prediction route in this file is registered on it
# and it is the object handed to the WSGI server in the __main__ block.
app = Flask('model_prediction_lightgbm——service')
  15. def str_to_list(arg):
  16. if arg == '':
  17. return []
  18. else:
  19. return arg.split(',')
  20. def forecast_data_distribution(pre_data,args):
  21. col_time = args['col_time']
  22. farm_id = args['farmId']
  23. dt = datetime.now(pytz.utc).astimezone(timezone("Asia/Shanghai")).strftime('%Y%m%d')
  24. tomorrow = (datetime.now(pytz.utc).astimezone(timezone("Asia/Shanghai")) + timedelta(days=1)).strftime('%Y-%m-%d')
  25. field_mapping = {'clearsky_ghi': 'clearskyGhi', 'dni_calcd': 'dniCalcd','surface_pressure':'surfacePressure',
  26. 'wd140m': 'tj_wd140','ws140m': 'tj_ws140','wd170m': 'tj_wd170','cldt': 'tj_tcc','wd70m': 'tj_wd70',
  27. 'ws100m': 'tj_ws100','DSWRFsfc': 'tj_radiation','wd10m': 'tj_wd10','TMP2m': 'tj_t2','wd30m': 'tj_wd30',
  28. 'ws30m': 'tj_ws30','rh2m': 'tj_rh','PRATEsfc': 'tj_pratesfc','ws170m': 'tj_ws170','wd50m': 'tj_wd50',
  29. 'ws50m': 'tj_ws50','wd100m': 'tj_wd100','ws70m': 'tj_ws70','ws10m': 'tj_ws10','psz': 'tj_pressure',
  30. 'cldl': 'tj_lcc','pres': 'tj_pres','dateTime':'date_time'}
  31. # 根据字段映射重命名列
  32. pre_data = pre_data.rename(columns=field_mapping)
  33. if len(pre_data)==0:
  34. send_message('lightgbm预测组件', farm_id, '请注意:获取NWP数据为空,预测文件无法生成!')
  35. result = get_xxl_dq(farm_id, dt)
  36. elif len(pre_data[pre_data[col_time].str.contains(tomorrow)])<96:
  37. send_message('lightgbm预测组件', farm_id, "日前数据记录缺失,不足96条,用DQ代替并补值!")
  38. result = get_xxl_dq(farm_id, dt)
  39. else:
  40. df = pre_data.sort_values(by=col_time).fillna(method='ffill').fillna(method='bfill')
  41. mongodb_connection, mongodb_database, mongodb_model_table, model_name = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/", \
  42. args['mongodb_database'], args['mongodb_model_table'], args['model_name']
  43. client = MongoClient(mongodb_connection)
  44. db = client[mongodb_database]
  45. collection = db[mongodb_model_table]
  46. model_data = collection.find_one({"model_name": model_name})
  47. if model_data is not None:
  48. model_binary = model_data['model'] # 确保这个字段是存储模型的二进制数据
  49. # 反序列化模型
  50. model = pickle.loads(model_binary)
  51. diff = set(model.feature_name()) - set(pre_data.columns)
  52. if len(diff) > 0:
  53. send_message('lightgbm预测组件', farm_id, f'NWP特征列缺失,使用DQ代替!features:{diff}')
  54. result = get_xxl_dq(farm_id, dt)
  55. else:
  56. df['power_forecast'] = model.predict(df[model.feature_name()])
  57. df.loc[df['power_forecast'] < 0, 'power_forecast'] = 0
  58. print("model predict result successfully!")
  59. if 'farm_id' not in df.columns:
  60. df['farm_id'] = farm_id
  61. result = df[['farm_id', 'date_time', 'power_forecast']]
  62. else:
  63. send_message('lightgbm预测组件', farm_id, "模型文件缺失,用DQ代替并补值!")
  64. result = get_xxl_dq(farm_id, dt)
  65. result['power_forecast'] = round(result['power_forecast'],2)
  66. return result
  67. def model_prediction(df,args):
  68. mongodb_connection,mongodb_database,mongodb_model_table,model_name,howLongAgo,farm_id,target = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['mongodb_model_table'],args['model_name'],int(args['howLongAgo']),args['farm_id'],args['target']
  69. client = MongoClient(mongodb_connection)
  70. db = client[mongodb_database]
  71. collection = db[mongodb_model_table]
  72. model_data = collection.find_one({"model_name": model_name})
  73. if 'is_limit' in df.columns:
  74. df = df[df['is_limit'] == False]
  75. if model_data is not None:
  76. model_binary = model_data['model'] # 确保这个字段是存储模型的二进制数据
  77. # 反序列化模型
  78. model = pickle.loads(model_binary)
  79. df['power_forecast'] = model.predict(df[model.feature_name()])
  80. df.loc[df['power_forecast'] < 0, 'power_forecast'] = 0
  81. df['model'] = model_name
  82. df['howLongAgo'] = howLongAgo
  83. df['farm_id'] = farm_id
  84. print("model predict result successfully!")
  85. return df[['dateTime','howLongAgo','model','farm_id','power_forecast',target]]
  86. @app.route('/model_prediction_lightgbm', methods=['POST'])
  87. def model_prediction_lightgbm():
  88. # 获取程序开始时间
  89. start_time = time.time()
  90. result = {}
  91. success = 0
  92. print("Program starts execution!")
  93. try:
  94. args = request.values.to_dict()
  95. print('args',args)
  96. logger.info(args)
  97. forecast_file = int(args['forecast_file'])
  98. power_df = get_data_from_mongo(args)
  99. if forecast_file == 1:
  100. predict_data = forecast_data_distribution(power_df, args)
  101. else:
  102. predict_data = model_prediction(power_df, args)
  103. insert_data_into_mongo(predict_data,args)
  104. success = 1
  105. except Exception as e:
  106. my_exception = traceback.format_exc()
  107. my_exception.replace("\n","\t")
  108. result['msg'] = my_exception
  109. end_time = time.time()
  110. result['success'] = success
  111. result['args'] = args
  112. result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
  113. result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
  114. print("Program execution ends!")
  115. return result
  116. if __name__=="__main__":
  117. print("Program starts execution!")
  118. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  119. logger = logging.getLogger("model_prediction_lightgbm log")
  120. from waitress import serve
  121. serve(app, host="0.0.0.0", port=10090)
  122. print("server start!")