|
@@ -0,0 +1,158 @@
|
|
|
|
+import pandas as pd
|
|
|
|
+from pymongo import MongoClient
|
|
|
|
+import pickle
|
|
|
|
+from flask import Flask, request
|
|
|
|
+import time
|
|
|
|
+import logging
|
|
|
|
+import traceback
|
|
|
|
+from common.database_dml import get_data_from_mongo, insert_data_into_mongo
|
|
|
|
+from common.alert import send_message
|
|
|
|
+from datetime import datetime, timedelta
|
|
|
|
+import pytz
|
|
|
|
+from pytz import timezone
|
|
|
|
+
|
|
|
|
+from common.processing_data_common import get_xxl_dq
|
|
|
|
+
|
|
|
|
+app = Flask('model_prediction_ml——service')
|
|
|
|
+
|
|
|
|
+"""
|
|
|
|
+于model_training_lightgbm.py
|
|
|
|
+1. 支持特征保存模型预测
|
|
|
|
+"""
|
|
|
|
+
|
|
|
|
+def str_to_list(arg):
|
|
|
|
+ if arg == '':
|
|
|
|
+ return []
|
|
|
|
+ else:
|
|
|
|
+ return arg.split(',')
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def forecast_data_distribution(pre_data, args):
|
|
|
|
+ col_time = args['col_time']
|
|
|
|
+ farm_id = args['farmId']
|
|
|
|
+ dt = datetime.now(pytz.utc).astimezone(timezone("Asia/Shanghai")).strftime('%Y%m%d')
|
|
|
|
+ tomorrow = (datetime.now(pytz.utc).astimezone(timezone("Asia/Shanghai")) + timedelta(days=1)).strftime('%Y-%m-%d')
|
|
|
|
+ field_mapping = {'clearsky_ghi': 'clearskyGhi', 'dni_calcd': 'dniCalcd', 'surface_pressure': 'surfacePressure',
|
|
|
|
+ 'wd140m': 'tj_wd140', 'ws140m': 'tj_ws140', 'wd170m': 'tj_wd170', 'cldt': 'tj_tcc',
|
|
|
|
+ 'wd70m': 'tj_wd70',
|
|
|
|
+ 'ws100m': 'tj_ws100', 'DSWRFsfc': 'tj_radiation', 'wd10m': 'tj_wd10', 'TMP2m': 'tj_t2',
|
|
|
|
+ 'wd30m': 'tj_wd30',
|
|
|
|
+ 'ws30m': 'tj_ws30', 'rh2m': 'tj_rh', 'PRATEsfc': 'tj_pratesfc', 'ws170m': 'tj_ws170',
|
|
|
|
+ 'wd50m': 'tj_wd50',
|
|
|
|
+ 'ws50m': 'tj_ws50', 'wd100m': 'tj_wd100', 'ws70m': 'tj_ws70', 'ws10m': 'tj_ws10',
|
|
|
|
+ 'psz': 'tj_pressure',
|
|
|
|
+ 'cldl': 'tj_lcc', 'pres': 'tj_pres', 'dateTime': 'date_time'}
|
|
|
|
+ # 根据字段映射重命名列
|
|
|
|
+ pre_data = pre_data.rename(columns=field_mapping)
|
|
|
|
+
|
|
|
|
+ if len(pre_data) == 0:
|
|
|
|
+ send_message('lightgbm预测组件', farm_id, '请注意:获取NWP数据为空,预测文件无法生成!')
|
|
|
|
+ result = get_xxl_dq(farm_id, dt)
|
|
|
|
+
|
|
|
|
+ elif len(pre_data[pre_data[col_time].str.contains(tomorrow)]) < 96:
|
|
|
|
+ send_message('lightgbm预测组件', farm_id, "日前数据记录缺失,不足96条,用DQ代替并补值!")
|
|
|
|
+ result = get_xxl_dq(farm_id, dt)
|
|
|
|
+ else:
|
|
|
|
+ df = pre_data.sort_values(by=col_time).fillna(method='ffill').fillna(method='bfill')
|
|
|
|
+ mongodb_connection, mongodb_database, mongodb_model_table, model_name = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/", \
|
|
|
|
+ args['mongodb_database'], args['mongodb_model_table'], args['model_name']
|
|
|
|
+ client = MongoClient(mongodb_connection)
|
|
|
|
+ db = client[mongodb_database]
|
|
|
|
+ collection = db[mongodb_model_table]
|
|
|
|
+ model_data = collection.find_one({"model_name": model_name})
|
|
|
|
+ print(model_data.keys())
|
|
|
|
+ if model_data is not None:
|
|
|
|
+ model_binary = model_data['model'] # 确保这个字段是存储模型的二进制数据
|
|
|
|
+ # 反序列化模型
|
|
|
|
+ model = pickle.loads(model_binary)
|
|
|
|
+ if 'features' in model_data.keys():
|
|
|
|
+ features = model_data['features']
|
|
|
|
+ else:
|
|
|
|
+ features = model.feature_name()
|
|
|
|
+ diff = set(features) - set(pre_data.columns)
|
|
|
|
+ if len(diff) > 0:
|
|
|
|
+ send_message('lightgbm预测组件', farm_id, f'NWP特征列缺失,使用DQ代替!features:{diff}')
|
|
|
|
+ result = get_xxl_dq(farm_id, dt)
|
|
|
|
+ else:
|
|
|
|
+ df['power_forecast'] = model.predict(df[features])
|
|
|
|
+ df.loc[df['power_forecast'] < 0, 'power_forecast'] = 0
|
|
|
|
+ print("model predict result successfully!")
|
|
|
|
+ if 'farm_id' not in df.columns:
|
|
|
|
+ df['farm_id'] = farm_id
|
|
|
|
+ result = df[['farm_id', 'date_time', 'power_forecast']]
|
|
|
|
+ else:
|
|
|
|
+ send_message('lightgbm预测组件', farm_id, "模型文件缺失,用DQ代替并补值!")
|
|
|
|
+ result = get_xxl_dq(farm_id, dt)
|
|
|
|
+ result['power_forecast'] = round(result['power_forecast'], 2)
|
|
|
|
+ return result
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def model_prediction(df, args):
|
|
|
|
+ mongodb_connection, mongodb_database, mongodb_model_table, model_name, howLongAgo, farm_id, target = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/", \
|
|
|
|
+ args['mongodb_database'], args['mongodb_model_table'], args['model_name'], int(args['howLongAgo']), args['farm_id'], \
|
|
|
|
+ args['target']
|
|
|
|
+ client = MongoClient(mongodb_connection)
|
|
|
|
+ db = client[mongodb_database]
|
|
|
|
+ collection = db[mongodb_model_table]
|
|
|
|
+ model_data = collection.find_one({"model_name": model_name})
|
|
|
|
+ if 'is_limit' in df.columns:
|
|
|
|
+ df = df[df['is_limit'] == False]
|
|
|
|
+
|
|
|
|
+ if model_data is not None:
|
|
|
|
+ model_binary = model_data['model'] # 确保这个字段是存储模型的二进制数据
|
|
|
|
+ # 反序列化模型
|
|
|
|
+ model = pickle.loads(model_binary)
|
|
|
|
+ if 'features' in model_data.keys():
|
|
|
|
+ features = model_data['features']
|
|
|
|
+ else:
|
|
|
|
+ features = model.feature_name()
|
|
|
|
+ df['power_forecast'] = model.predict(df[features])
|
|
|
|
+ df.loc[df['power_forecast'] < 0, 'power_forecast'] = 0
|
|
|
|
+ df['model'] = model_name
|
|
|
|
+ df['howLongAgo'] = howLongAgo
|
|
|
|
+ df['farm_id'] = farm_id
|
|
|
|
+ print("model predict result successfully!")
|
|
|
|
+
|
|
|
|
+ return df[['dateTime', 'howLongAgo', 'model', 'farm_id', 'power_forecast', target]]
|
|
|
|
+
|
|
|
|
+@app.route('/model_prediction_ml', methods=['POST'])
|
|
|
|
+def model_prediction_ml():
|
|
|
|
+ # 获取程序开始时间
|
|
|
|
+ start_time = time.time()
|
|
|
|
+ result = {}
|
|
|
|
+ success = 0
|
|
|
|
+ print("Program starts execution!")
|
|
|
|
+ try:
|
|
|
|
+ args = request.values.to_dict()
|
|
|
|
+ print('args', args)
|
|
|
|
+ forecast_file = int(args['forecast_file'])
|
|
|
|
+ power_df = get_data_from_mongo(args)
|
|
|
|
+ if forecast_file == 1:
|
|
|
|
+ predict_data = forecast_data_distribution(power_df, args)
|
|
|
|
+ else:
|
|
|
|
+ predict_data = model_prediction(power_df, args)
|
|
|
|
+ insert_data_into_mongo(predict_data, args)
|
|
|
|
+ success = 1
|
|
|
|
+ except Exception as e:
|
|
|
|
+ my_exception = traceback.format_exc()
|
|
|
|
+ my_exception.replace("\n", "\t")
|
|
|
|
+ result['msg'] = my_exception
|
|
|
|
+ end_time = time.time()
|
|
|
|
+
|
|
|
|
+ result['success'] = success
|
|
|
|
+ result['args'] = args
|
|
|
|
+ result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
|
|
|
|
+ result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
|
|
|
|
+ print("Program execution ends!")
|
|
|
|
+ return result
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+if __name__ == "__main__":
|
|
|
|
+ print("Program starts execution!")
|
|
|
|
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
|
|
+ logger = logging.getLogger("model_prediction_ml log")
|
|
|
|
+ from waitress import serve
|
|
|
|
+
|
|
|
|
+ serve(app, host="0.0.0.0", port=10126)
|
|
|
|
+ print("server start!")
|