model_prediction_ml.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. import pandas as pd
  2. from pymongo import MongoClient
  3. import pickle
  4. from flask import Flask, request, g
  5. import time
  6. import logging
  7. import traceback
  8. from common.database_dml import get_data_from_mongo, insert_data_into_mongo
  9. from common.alert import send_message
  10. from datetime import datetime, timedelta
  11. import pytz
  12. from pytz import timezone
  13. from common.log_utils import init_request_logging, teardown_request_logging
  14. from common.processing_data_common import get_xxl_dq
  15. app = Flask('model_prediction_ml——service')
  16. @app.before_request
  17. def setup_logging():
  18. init_request_logging(logger)
  19. # 请求后清理日志处理器
  20. @app.after_request
  21. def teardown_logging(response):
  22. return teardown_request_logging(response, logger)
  23. def str_to_list(arg):
  24. if arg == '':
  25. return []
  26. else:
  27. return arg.split(',')
  28. def forecast_data_distribution(pre_data, args):
  29. col_time = args['col_time']
  30. farm_id = args['farm_id']
  31. dt = datetime.now(pytz.utc).astimezone(timezone("Asia/Shanghai")).strftime('%Y%m%d')
  32. tomorrow = (datetime.now(pytz.utc).astimezone(timezone("Asia/Shanghai")) + timedelta(days=1)).strftime('%Y-%m-%d')
  33. field_mapping = {'clearsky_ghi': 'clearskyGhi', 'dni_calcd': 'dniCalcd', 'surface_pressure': 'surfacePressure',
  34. 'wd140m': 'tj_wd140', 'ws140m': 'tj_ws140', 'wd170m': 'tj_wd170', 'cldt': 'tj_tcc',
  35. 'wd70m': 'tj_wd70',
  36. 'ws100m': 'tj_ws100', 'DSWRFsfc': 'tj_radiation', 'wd10m': 'tj_wd10', 'TMP2m': 'tj_t2',
  37. 'wd30m': 'tj_wd30',
  38. 'ws30m': 'tj_ws30', 'rh2m': 'tj_rh', 'PRATEsfc': 'tj_pratesfc', 'ws170m': 'tj_ws170',
  39. 'wd50m': 'tj_wd50',
  40. 'ws50m': 'tj_ws50', 'wd100m': 'tj_wd100', 'ws70m': 'tj_ws70', 'ws10m': 'tj_ws10',
  41. 'psz': 'tj_pressure',
  42. 'cldl': 'tj_lcc', 'pres': 'tj_pres', 'dateTime': 'date_time'}
  43. # 根据字段映射重命名列
  44. pre_data = pre_data.rename(columns=field_mapping)
  45. if len(pre_data) == 0:
  46. logging.info("nwp dataframe is empty")
  47. send_message('lightgbm预测组件', farm_id, '请注意:获取NWP数据为空,预测文件无法生成!')
  48. result = get_xxl_dq(farm_id, dt)
  49. elif len(pre_data[pre_data[col_time].str.contains(tomorrow)]) < 96:
  50. send_message('lightgbm预测组件', farm_id, "日前数据记录缺失,不足96条,用DQ代替并补值!")
  51. result = get_xxl_dq(farm_id, dt)
  52. else:
  53. df = pre_data.sort_values(by=col_time).fillna(method='ffill').fillna(method='bfill')
  54. mongodb_connection, mongodb_database, mongodb_model_table, model_name = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/", \
  55. args['mongodb_database'], args['mongodb_model_table'], args['model_name']
  56. client = MongoClient(mongodb_connection)
  57. db = client[mongodb_database]
  58. collection = db[mongodb_model_table]
  59. model_data = collection.find_one({"model_name": model_name})
  60. logger.info(f"use {model_data['model_name']} features: {model_data['features']}")
  61. if model_data is not None:
  62. model_binary = model_data['model'] # 确保这个字段是存储模型的二进制数据
  63. # 反序列化模型
  64. model = pickle.loads(model_binary)
  65. if 'features' in model_data.keys():
  66. features = model_data['features']
  67. else:
  68. features = model.feature_name()
  69. diff = set(features) - set(pre_data.columns)
  70. if len(diff) > 0:
  71. send_message('lightgbm预测组件', farm_id, f'NWP特征列缺失,使用DQ代替!features:{diff}')
  72. result = get_xxl_dq(farm_id, dt)
  73. else:
  74. df['power_forecast'] = model.predict(df[features])
  75. df.loc[df['power_forecast'] < 0, 'power_forecast'] = 0
  76. logger.info("model predict result successfully!")
  77. if 'farm_id' not in df.columns:
  78. df['farm_id'] = farm_id
  79. result = df[['farm_id', 'date_time', 'power_forecast']]
  80. else:
  81. send_message('lightgbm预测组件', farm_id, "模型文件缺失,用DQ代替并补值!")
  82. result = get_xxl_dq(farm_id, dt)
  83. result['power_forecast'] = round(result['power_forecast'], 2)
  84. return result
  85. def model_prediction(df, args):
  86. mongodb_connection, mongodb_database, mongodb_model_table, model_name, howLongAgo, farm_id, target = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/", \
  87. args['mongodb_database'], args['mongodb_model_table'], args['model_name'], int(args['howLongAgo']), args['farm_id'], \
  88. args['target']
  89. client = MongoClient(mongodb_connection)
  90. db = client[mongodb_database]
  91. collection = db[mongodb_model_table]
  92. model_data = collection.find_one({"model_name": model_name})
  93. if 'is_limit' in df.columns:
  94. df = df[df['is_limit'] == False]
  95. logger.info(f"use {model_data['model_name']} features: {model_data['features']}")
  96. if model_data is not None:
  97. model_binary = model_data['model'] # 确保这个字段是存储模型的二进制数据
  98. # 反序列化模型
  99. model = pickle.loads(model_binary)
  100. if 'features' in model_data.keys():
  101. features = model_data['features']
  102. else:
  103. features = model.feature_name()
  104. df.dropna(subset=features, inplace=True)
  105. df['power_forecast'] = model.predict(df[features])
  106. df.loc[df['power_forecast'] < 0, 'power_forecast'] = 0
  107. df['model'] = model_name
  108. df['howLongAgo'] = howLongAgo
  109. df['farm_id'] = farm_id
  110. logger.info("model predict result successfully!")
  111. return df[['dateTime', 'howLongAgo', 'model', 'farm_id', 'power_forecast', target]]
  112. @app.route('/model_prediction_ml', methods=['POST'])
  113. def model_prediction_ml():
  114. # 获取程序开始时间
  115. start_time = time.time()
  116. result = {}
  117. success = 0
  118. print("Program starts execution!")
  119. try:
  120. args = request.values.to_dict()
  121. print('args', args)
  122. forecast_file = int(args['forecast_file'])
  123. power_df = get_data_from_mongo(args)
  124. if forecast_file == 1:
  125. predict_data = forecast_data_distribution(power_df, args)
  126. else:
  127. predict_data = model_prediction(power_df, args)
  128. insert_data_into_mongo(predict_data, args)
  129. success = 1
  130. except Exception as e:
  131. my_exception = traceback.format_exc()
  132. logger.error(my_exception)
  133. end_time = time.time()
  134. result['success'] = success
  135. result['args'] = args
  136. result['log'] = result['log'] = g.log_stream.getvalue().splitlines()
  137. result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
  138. result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
  139. print("Program execution ends!")
  140. return result
  141. if __name__ == "__main__":
  142. print("Program starts execution!")
  143. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  144. logger = logging.getLogger("model_prediction_ml log")
  145. from waitress import serve
  146. serve(app, host="0.0.0.0", port=10129)
  147. print("server start!")