Ver código fonte

feat 支持svr与样本权重

hzh 2 semanas atrás
pai
commit
f006a988e9

+ 3 - 1
data_processing/data_operation/weight.py

@@ -1,10 +1,11 @@
 import numpy as np
 
 
-def balance_weights(y: np.ndarray, bins: int = 10, normalize: bool = True, **kwargs) -> np.ndarray:
+def balance_weights(y: np.ndarray, bins=10, normalize: bool = True, **kwargs) -> np.ndarray:
     """
     平衡权重,分布数量越少权重越大
     """
+    bins = int(bins)
     counts, bin_edges = np.histogram(y, bins=bins)
 
     # digitize 不使用 right=True,这样最小值也能落在 bin 0 开始
@@ -26,6 +27,7 @@ def south_weight(target: np.ndarray, cap, **kwargs) -> np.ndarray:
     应付南方点网的奇怪考核
     为了不把曲线压太低,这里有所收敛(添加开方处理,不让权重分布过于离散)
     """
+    cap = float(cap)
     weight = 1 / np.sqrt(np.where(target < 0.2 * cap, 0.2 * cap, target))
     return weight
 

+ 2 - 2
models_processing/model_predict/model_prediction_ml.py

@@ -26,7 +26,6 @@ def str_to_list(arg):
     else:
         return arg.split(',')
 
-
 def forecast_data_distribution(pre_data, args):
     col_time = args['col_time']
     farm_id = args['farmId']
@@ -86,7 +85,6 @@ def forecast_data_distribution(pre_data, args):
     result['power_forecast'] = round(result['power_forecast'], 2)
     return result
 
-
 def model_prediction(df, args):
     mongodb_connection, mongodb_database, mongodb_model_table, model_name, howLongAgo, farm_id, target = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/", \
     args['mongodb_database'], args['mongodb_model_table'], args['model_name'], int(args['howLongAgo']), args['farm_id'], \
@@ -106,6 +104,7 @@ def model_prediction(df, args):
             features = model_data['features']
         else:
             features = model.feature_name()
+        df.dropna(subset=features, inplace=True)
         df['power_forecast'] = model.predict(df[features])
         df.loc[df['power_forecast'] < 0, 'power_forecast'] = 0
         df['model'] = model_name
@@ -135,6 +134,7 @@ def model_prediction_ml():
         success = 1
     except Exception as e:
         my_exception = traceback.format_exc()
+        print(my_exception)
         my_exception.replace("\n", "\t")
         result['msg'] = my_exception
     end_time = time.time()

+ 20 - 11
models_processing/model_train/model_training_ml.py

@@ -15,6 +15,18 @@ from data_processing.data_operation.weight import WEIGHT_REGISTER
 
 app = Flask('model_training_ml——service')
 
+def get_sample_weight(df, label, args):
+    # 样本权重
+    if 'sample_weight' in args.keys():
+        if args['sample_weight'] in WEIGHT_REGISTER.keys():
+            sample_weight = WEIGHT_REGISTER[args['sample_weight']](df[label].values.reshape(-1), **args)
+        elif args['sample_weight'] in df.columns.tolist():
+            sample_weight = df[args['sample_weight']].values.reshape(-1)
+        else:
+            sample_weight = None
+            print('sample_weight is neither in the predefined weights nor a column of the DataFrame, not applicable')
+    return sample_weight
+
 def train_lgb(data_split, categorical_features, model_params, num_boost_round=100, sample_weight=None):
     X_train, X_test, y_train, y_test = data_split
     # 创建LightGBM数据集
@@ -67,19 +79,15 @@ def build_model(df, args):
     df = missing_features(df, features, col_time)
     df = df[~np.isnan(df[label])]
     # 拆分数据为训练集和测试集
-    X_train, X_test, y_train, y_test = train_test_split(df[features], df[label], test_size=0.2, random_state=42,
+    df_train, df_test = train_test_split(df, test_size=0.2, random_state=42,
                                                         shuffle=False)
+    X_train, y_train = df_train[features].values, df_train[label].values
+    X_test, y_test = df_test[features].values, df_test[label].values
+
+    # 获取样本权重
+    sample_weight = get_sample_weight(df_train, label=label, args=args)
+
     model_type = args['model_type']
-    sample_weight = None
-    # 样本权重
-    if 'sample_weight' in args.keys():
-        if args['sample_weight'] in WEIGHT_REGISTER.keys():
-            sample_weight = WEIGHT_REGISTER[args['sample_weight']](df[label].values.reshape(-1), **args)
-        elif args['sample_weight'] in df.columns.tolist():
-            sample_weight = df[args['sample_weight']].values.reshape(-1)
-        else:
-            sample_weight = None
-            print('sample_weight is neither in the predefined weights nor a column of the DataFrame, not applicable')
     # 区分常规机器学习模型和lgb,这里只实例化svr,后续可扩展
     if model_type == "lightgbm":
         num_boost_round = int(args['num_boost_round'])
@@ -114,6 +122,7 @@ def model_training_ml():
         insert_pickle_model_into_mongo(model, args, features=features)
         success = 1
     except Exception as e:
+        print(e)
         my_exception = traceback.format_exc()
         my_exception.replace("\n", "\t")
         result['msg'] = my_exception

+ 118 - 0
post_processing/post_process.py

@@ -0,0 +1,118 @@
+import pandas as pd
+from flask import Flask, request, jsonify
+import time
+import logging
+import traceback
+
+from common.database_dml import get_data_from_mongo, insert_data_into_mongo
+
+app = Flask('post_processing——service')
+
+"""
+id = "${id}"
+
+cap = ${cap}
+
+参数
+{
+    'mongodb_database': 'hzh_ftp',
+    'mongodb_read_table': f'{id}_PRED',
+    'mongodb_write_table': f'{id}_PRED',
+    'col_time':  "dateTime",
+    'smooth_window': 3
+    'plant_type': 'solar',
+    'mongodb_nwp_table': f'{id}_NWP_D1'
+}
+
+"""
+
+
+def get_data(args):
+    df = get_data_from_mongo(args)
+    col_time = args['col_time']
+    if not df.empty:
+        print("预测数据加载成功!")
+        df[col_time] = pd.to_datetime(df[col_time])
+        df.set_index(col_time, inplace=True)
+        df.sort_index(inplace=True)
+    else:
+        raise ValueError("未获取到预测数据。")
+    return df
+
+
+def predict_result_adjustment(df, args):
+    """
+    光伏/风电 数据后处理 主要操作
+    1. 光伏 (夜间 置零 + 平滑)
+    2. 风电 (平滑)
+    3. cap 封顶
+    """
+    mongodb_database, plant_type, cap, col_time = args['mongodb_database'], args['plant_type'], float(args['cap']), \
+        args['col_time']
+    if 'smooth_window' in args.keys():
+        smooth_window = int(args['smooth_window'])
+    else:
+        smooth_window = 3
+
+    # 平滑
+    df_cp = df.copy()
+    df_cp['power_forecast'] = df_cp['power_forecast'].rolling(window=smooth_window, min_periods=1,
+                                                              center=True).mean().clip(0, 0.985 * cap)
+    print("smooth processed")
+
+    # 光伏晚上置零
+    if plant_type == 'solar' and 'mongodb_nwp_table' in args.keys():
+        nwp_param = {
+            'mongodb_database': mongodb_database,
+            'mongodb_read_table': args['mongodb_nwp_table'],
+            'col_time': col_time
+        }
+        nwp = get_data(nwp_param)
+
+        df_cp = df_cp.join(nwp['radiation'])
+        df_cp.loc[nwp['radiation'] == 0, 'power_forecast'] = 0
+        df_cp['power_forecast'] = round(df_cp['power_forecast'], 2)
+        df_cp.drop(columns=['radiation'], inplace=True)
+        print("solar processed")
+    df_cp.reset_index(inplace=True)
+    df_cp[col_time] = df_cp[col_time].dt.strftime('%Y-%m-%d %H:%M:%S')
+    return df_cp
+
+
+@app.route('/post_process', methods=['POST'])
+def data_join():
+    # 获取程序开始时间
+    start_time = time.time()
+    result = {}
+    success = 0
+    print("Program starts execution!")
+    try:
+        args = request.values.to_dict()
+        print('args', args)
+        logger.info(args)
+        df_pre = get_data(args)
+        res_df = predict_result_adjustment(df_pre, args)
+        insert_data_into_mongo(res_df, args)
+        success = 1
+    except Exception as e:
+        my_exception = traceback.format_exc()
+        print(my_exception)
+        my_exception.replace("\n", "\t")
+        result['msg'] = my_exception
+    end_time = time.time()
+    result['success'] = success
+    result['args'] = args
+    result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
+    result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
+    print("Program execution ends!")
+    return result
+
+
+if __name__ == "__main__":
+    print("Program starts execution!")
+    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    logger = logging.getLogger("post_processing")
+    from waitress import serve
+
+    serve(app, host="0.0.0.0", port=10098)
+    print("server start!")