Explorar el Código

feat 支持lightgbm和SVR

hzh hace 1 mes
padre
commit
0cc21a30c7

+ 4 - 1
common/database_dml.py

@@ -127,7 +127,7 @@ def get_data_fromMysql(params):
     return df
 
 
-def insert_pickle_model_into_mongo(model, args):
+def insert_pickle_model_into_mongo(model, args, features=None):
     mongodb_connection, mongodb_database, mongodb_write_table, model_name = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/", \
     args['mongodb_database'], args['mongodb_write_table'], args['model_name']
     client = MongoClient(mongodb_connection)
@@ -138,6 +138,9 @@ def insert_pickle_model_into_mongo(model, args):
         'model_name': model_name,
         'model': model_bytes,  # 将模型字节流存入数据库
     }
+    # 保存模型特征
+    if features is not None:
+        model_data['features'] = features
     print('Training completed!')
 
     if mongodb_write_table in db.list_collection_names():

+ 1 - 2
models_processing/model_predict/model_prediction_lightgbm.py

@@ -134,5 +134,4 @@ if __name__=="__main__":
     logger = logging.getLogger("model_prediction_lightgbm log")
     from waitress import serve
     serve(app, host="0.0.0.0", port=10090)
-    print("server start!")
-    
+    print("server start!")

+ 158 - 0
models_processing/model_predict/model_prediction_ml.py

@@ -0,0 +1,158 @@
+import pandas as pd
+from pymongo import MongoClient
+import pickle
+from flask import Flask, request
+import time
+import logging
+import traceback
+from common.database_dml import get_data_from_mongo, insert_data_into_mongo
+from common.alert import send_message
+from datetime import datetime, timedelta
+import pytz
+from pytz import timezone
+
+from common.processing_data_common import get_xxl_dq
+
+app = Flask('model_prediction_ml——service')
+
+"""
+基于model_training_lightgbm.py
+1. 支持特征保存模型预测
+"""
+
+def str_to_list(arg):
+    """Split a comma-separated string into a list; '' yields an empty list."""
+    # ''.split(',') would give [''], hence the explicit empty-string case.
+    return arg.split(',') if arg != '' else []
+
+
+def forecast_data_distribution(pre_data, args):
+    """Build the day-ahead power forecast for one farm, falling back to DQ data.
+
+    The NWP frame is renamed with a fixed column mapping and then checked.
+    ``get_xxl_dq(farm_id, dt)`` is returned instead of a model prediction when
+    the NWP frame is empty, when fewer than 96 rows match tomorrow's date
+    (Asia/Shanghai), when the stored model cannot be found, or when features
+    required by the model are missing from the NWP columns. An alert is sent
+    through ``send_message`` in each of those cases.
+
+    Args:
+        pre_data: NWP DataFrame (the route passes the result of
+            ``get_data_from_mongo``).
+        args: request parameters; reads ``col_time``, ``farmId``,
+            ``mongodb_database``, ``mongodb_model_table`` and ``model_name``.
+            NOTE(review): ``model_prediction`` reads ``farm_id`` instead of
+            ``farmId`` - confirm both spellings are intended.
+
+    Returns:
+        The forecast frame with ``power_forecast`` rounded to 2 decimals. On
+        the model path it has the columns ``farm_id``, ``date_time`` and
+        ``power_forecast``.
+    """
+    col_time = args['col_time']
+    farm_id = args['farmId']
+    # Take "now" once so dt and tomorrow cannot straddle midnight.
+    now_shanghai = datetime.now(pytz.utc).astimezone(timezone("Asia/Shanghai"))
+    dt = now_shanghai.strftime('%Y%m%d')
+    tomorrow = (now_shanghai + timedelta(days=1)).strftime('%Y-%m-%d')
+    field_mapping = {'clearsky_ghi': 'clearskyGhi', 'dni_calcd': 'dniCalcd', 'surface_pressure': 'surfacePressure',
+                     'wd140m': 'tj_wd140', 'ws140m': 'tj_ws140', 'wd170m': 'tj_wd170', 'cldt': 'tj_tcc',
+                     'wd70m': 'tj_wd70',
+                     'ws100m': 'tj_ws100', 'DSWRFsfc': 'tj_radiation', 'wd10m': 'tj_wd10', 'TMP2m': 'tj_t2',
+                     'wd30m': 'tj_wd30',
+                     'ws30m': 'tj_ws30', 'rh2m': 'tj_rh', 'PRATEsfc': 'tj_pratesfc', 'ws170m': 'tj_ws170',
+                     'wd50m': 'tj_wd50',
+                     'ws50m': 'tj_ws50', 'wd100m': 'tj_wd100', 'ws70m': 'tj_ws70', 'ws10m': 'tj_ws10',
+                     'psz': 'tj_pressure',
+                     'cldl': 'tj_lcc', 'pres': 'tj_pres', 'dateTime': 'date_time'}
+    # Rename columns according to the field mapping.
+    pre_data = pre_data.rename(columns=field_mapping)
+
+    if len(pre_data) == 0:
+        send_message('lightgbm预测组件', farm_id, '请注意:获取NWP数据为空,预测文件无法生成!')
+        result = get_xxl_dq(farm_id, dt)
+
+    elif len(pre_data[pre_data[col_time].str.contains(tomorrow)]) < 96:
+        send_message('lightgbm预测组件', farm_id, "日前数据记录缺失,不足96条,用DQ代替并补值!")
+        result = get_xxl_dq(farm_id, dt)
+    else:
+        # ffill()/bfill() replace the deprecated fillna(method=...) spelling.
+        df = pre_data.sort_values(by=col_time).ffill().bfill()
+        # NOTE(review): the connection string embeds credentials in source;
+        # it should come from configuration or the environment.
+        mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
+        mongodb_database = args['mongodb_database']
+        mongodb_model_table = args['mongodb_model_table']
+        model_name = args['model_name']
+        client = MongoClient(mongodb_connection)
+        try:
+            collection = client[mongodb_database][mongodb_model_table]
+            model_data = collection.find_one({"model_name": model_name})
+        finally:
+            # Release the connection even if the lookup fails.
+            client.close()
+        if model_data is not None:
+            # Only touch model_data after the None check; inspecting it first
+            # raised AttributeError and made the fallback below unreachable.
+            print(model_data.keys())
+            # NOTE(review): pickle.loads executes arbitrary code; the model
+            # collection must only be writable by trusted producers.
+            model = pickle.loads(model_data['model'])
+            if 'features' in model_data:
+                features = model_data['features']
+            else:
+                # Older documents have no saved feature list; this fallback
+                # relies on the model exposing feature_name().
+                features = model.feature_name()
+            diff = set(features) - set(pre_data.columns)
+            if len(diff) > 0:
+                send_message('lightgbm预测组件', farm_id, f'NWP特征列缺失,使用DQ代替!features:{diff}')
+                result = get_xxl_dq(farm_id, dt)
+            else:
+                df['power_forecast'] = model.predict(df[features])
+                # Clamp negative predictions to zero.
+                df.loc[df['power_forecast'] < 0, 'power_forecast'] = 0
+                print("model predict result  successfully!")
+                if 'farm_id' not in df.columns:
+                    df['farm_id'] = farm_id
+                # copy() so the rounding below writes to an independent frame
+                # rather than to a slice of df.
+                result = df[['farm_id', 'date_time', 'power_forecast']].copy()
+        else:
+            send_message('lightgbm预测组件', farm_id, "模型文件缺失,用DQ代替并补值!")
+            result = get_xxl_dq(farm_id, dt)
+    result['power_forecast'] = round(result['power_forecast'], 2)
+    return result
+
+
+def model_prediction(df, args):
+    """Predict power for ``df`` with the pickled model stored in MongoDB.
+
+    Rows with ``is_limit`` set are dropped when that column exists. The model
+    is looked up by ``model_name``; its saved ``features`` list is used when
+    present, otherwise ``model.feature_name()``. Negative predictions are
+    clamped to zero.
+
+    Args:
+        df: input DataFrame containing the feature columns, ``dateTime`` and
+            the target column.
+        args: request parameters; reads ``mongodb_database``,
+            ``mongodb_model_table``, ``model_name``, ``howLongAgo``,
+            ``farm_id`` and ``target``.
+
+    Returns:
+        DataFrame with the columns ``dateTime``, ``howLongAgo``, ``model``,
+        ``farm_id``, ``power_forecast`` and the target column.
+
+    Raises:
+        KeyError: if a required key is missing from ``args`` or no model
+            document named ``model_name`` exists.
+    """
+    # NOTE(review): the connection string embeds credentials in source; it
+    # should come from configuration or the environment.
+    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
+    mongodb_database = args['mongodb_database']
+    mongodb_model_table = args['mongodb_model_table']
+    model_name = args['model_name']
+    howLongAgo = int(args['howLongAgo'])
+    farm_id = args['farm_id']
+    target = args['target']
+
+    client = MongoClient(mongodb_connection)
+    try:
+        model_data = client[mongodb_database][mongodb_model_table].find_one({"model_name": model_name})
+    finally:
+        # Release the connection even if the lookup fails.
+        client.close()
+    if model_data is None:
+        # Previously this fell through to an opaque KeyError on
+        # 'power_forecast'; keep the exception type but say what is wrong.
+        raise KeyError(f"model '{model_name}' not found in {mongodb_database}.{mongodb_model_table}")
+
+    if 'is_limit' in df.columns:
+        # copy() so the column assignments below do not write to a slice.
+        df = df[df['is_limit'] == False].copy()
+
+    # NOTE(review): pickle.loads executes arbitrary code; the model
+    # collection must only be writable by trusted producers.
+    model = pickle.loads(model_data['model'])
+    if 'features' in model_data:
+        features = model_data['features']
+    else:
+        # Older documents have no saved feature list; this fallback relies on
+        # the model exposing feature_name().
+        features = model.feature_name()
+    df['power_forecast'] = model.predict(df[features])
+    # Clamp negative predictions to zero.
+    df.loc[df['power_forecast'] < 0, 'power_forecast'] = 0
+    df['model'] = model_name
+    df['howLongAgo'] = howLongAgo
+    df['farm_id'] = farm_id
+    print("model predict result  successfully!")
+
+    return df[['dateTime', 'howLongAgo', 'model', 'farm_id', 'power_forecast', target]]
+
+@app.route('/model_prediction_ml', methods=['POST'])
+def model_prediction_ml():
+    """Handle a prediction request and write the forecast back to MongoDB.
+
+    Reads the request parameters, loads the input data with
+    ``get_data_from_mongo``, runs ``forecast_data_distribution`` when
+    ``forecast_file`` is 1 (otherwise ``model_prediction``) and stores the
+    result with ``insert_data_into_mongo``.
+
+    Returns:
+        dict with ``success`` (1 or 0), the echoed ``args``, ``start_time``,
+        ``end_time`` and, on failure, ``msg`` holding the traceback on a
+        single line.
+    """
+    # Record the start time of the request.
+    start_time = time.time()
+    result = {}
+    success = 0
+    # Bound before the try block so the response can always echo it, even
+    # when reading the request itself fails.
+    args = {}
+    print("Program starts execution!")
+    try:
+        args = request.values.to_dict()
+        print('args', args)
+        forecast_file = int(args['forecast_file'])
+        power_df = get_data_from_mongo(args)
+        if forecast_file == 1:
+            predict_data = forecast_data_distribution(power_df, args)
+        else:
+            predict_data = model_prediction(power_df, args)
+        insert_data_into_mongo(predict_data, args)
+        success = 1
+    except Exception:
+        # str.replace returns a new string, so its result has to be kept.
+        result['msg'] = traceback.format_exc().replace("\n", "\t")
+    end_time = time.time()
+
+    result['success'] = success
+    result['args'] = args
+    result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
+    result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
+    print("Program execution ends!")
+    return result
+
+
+
+# Script entry point: configure logging and serve the Flask app with waitress.
+if __name__ == "__main__":
+    print("Program starts execution!")
+    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    logger = logging.getLogger("model_prediction_ml log")
+    from waitress import serve
+
+    # serve() blocks until the server stops, so announce start-up before it;
+    # printed afterwards, the message would only appear at shutdown.
+    print("server start!")
+    serve(app, host="0.0.0.0", port=10126)

+ 142 - 0
models_processing/model_train/model_training_ml.py

@@ -0,0 +1,142 @@
+import lightgbm as lgb
+import numpy as np
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import mean_squared_error, mean_absolute_error
+from flask import Flask, request
+import time
+import traceback
+import logging
+from common.database_dml import get_data_from_mongo, insert_pickle_model_into_mongo
+from common.processing_data_common import missing_features, str_to_list
+from sklearn.pipeline import Pipeline
+from sklearn.svm import SVR
+from sklearn.preprocessing import MinMaxScaler
+
+app = Flask('model_training_ml——service')
+
+"""
+基于model_training_lightgbm.py
+机器学习通用训练方法,特点
+1. 保存模型同时,保存模型特征
+2. 支持模型训练样本权重(需要在预处理部分生成权重特征)
+
+参数格式如下
+
+"""
+
+
+def train_lgb(data_split, categorical_features, model_params, num_boost_round, sample_weight=None):
+    """Train a LightGBM regressor and predict on the held-out split.
+
+    Args:
+        data_split: sequence ``(X_train, X_test, y_train, y_test)``.
+        categorical_features: passed as ``categorical_feature`` of the
+            training dataset.
+        model_params: overrides applied on top of the default parameters.
+        num_boost_round: number of boosting rounds.
+        sample_weight: optional weights for the training dataset.
+
+    Returns:
+        Tuple ``(booster, predictions for X_test)``.
+    """
+    X_train, X_test, y_train, y_test = data_split
+    # Build the LightGBM datasets; the eval set references the training set.
+    train_set = lgb.Dataset(X_train, y_train, categorical_feature=categorical_features, weight=sample_weight)
+    eval_set = lgb.Dataset(X_test, y_test, reference=train_set)
+    # Defaults first; caller-supplied model_params override them.
+    train_params = dict(objective='regression', metric='rmse', boosting_type='gbdt', verbose=1)
+    print(type(model_params))
+    train_params.update(model_params)
+    print('Starting training...')
+    booster = lgb.train(
+        train_params,
+        train_set,
+        num_boost_round=num_boost_round,
+        valid_sets=[train_set, eval_set],
+    )
+    predictions = booster.predict(X_test, num_iteration=booster.best_iteration)
+    return booster, predictions
+
+
+def train_svr(data_split, model_params, sample_weight=None):
+    """Fit a MinMaxScaler + SVR pipeline and predict on the held-out split.
+
+    Args:
+        data_split: sequence ``(X_train, X_test, y_train, y_test)``.
+        model_params: keyword arguments forwarded to ``SVR``.
+        sample_weight: optional weights routed to the SVR step of the
+            pipeline.
+
+    Returns:
+        Tuple ``(fitted pipeline, predictions for X_test)``.
+    """
+    X_train, X_test, y_train, _ = data_split
+
+    steps = [
+        ('scaler', MinMaxScaler()),
+        ('model', SVR(**model_params)),
+    ]
+    pipeline = Pipeline(steps)
+
+    print('Starting training...')
+    # The "model__" prefix routes the fit parameter to the SVR step.
+    pipeline.fit(X_train, y_train, model__sample_weight=sample_weight)
+    predictions = pipeline.predict(X_test)
+    return pipeline, predictions
+
+
+def build_model(df, args):
+    """Clean the data, train the requested model and report test metrics.
+
+    Args:
+        df: training DataFrame containing the feature columns, the label
+            column and, optionally, ``is_limit`` and a sample-weight column.
+        args: request parameters; reads ``numerical_features``,
+            ``categorical_features`` (comma-separated), ``label``,
+            ``model_name``, ``num_boost_round``, ``model_params``,
+            ``col_time``, ``model_type`` (``lightgbm`` or ``SVR``) and the
+            optional ``sample_weight``.
+
+    Returns:
+        Tuple ``(trained model, list of feature names)``.
+
+    Raises:
+        ValueError: if ``model_type`` is unsupported or the sample-weight
+            column is not present in the cleaned data.
+    """
+    np.random.seed(42)
+    numerical_features = str_to_list(args['numerical_features'])
+    categorical_features = str_to_list(args['categorical_features'])
+    label = args['label']
+    # Not used here, but required when the model is stored afterwards; read
+    # it now so a missing key fails before training rather than after.
+    args['model_name']
+    # NOTE(review): required even for SVR, which does not use it.
+    num_boost_round = int(args['num_boost_round'])
+    # NOTE(review): eval() on a request parameter executes arbitrary code;
+    # consider ast.literal_eval or JSON once callers are confirmed to send
+    # plain literals.
+    model_params = eval(args['model_params'])
+    col_time = args['col_time']
+    # Request values are strings, so 'sample_weight' cannot be the weight
+    # vector itself; it is treated as the name of a weight column produced
+    # during preprocessing (presumably - confirm with callers). An empty
+    # value means "no weights".
+    weight_col = args.get('sample_weight') or None
+
+    features = numerical_features + categorical_features
+    print("features:************", features)
+    if 'is_limit' in df.columns:
+        df = df[df['is_limit'] == False]
+    # Drop days whose average feature missing rate exceeds 20%.
+    df = missing_features(df, features, col_time)
+    df = df[~np.isnan(df[label])]
+    if weight_col is not None and weight_col not in df.columns:
+        raise ValueError(f"sample_weight column {weight_col!r} not found in training data")
+
+    # Chronological split (shuffle=False); the weights go through the same
+    # call so they stay aligned with the training rows.
+    arrays = [df[features], df[label]]
+    if weight_col is not None:
+        arrays.append(df[weight_col])
+    splits = train_test_split(*arrays, test_size=0.2, random_state=42, shuffle=False)
+    X_train, X_test, y_train, y_test = splits[:4]
+    sample_weight = splits[4] if weight_col is not None else None
+
+    model_type = args['model_type']
+    # LightGBM has its own training path; SVR is the only other estimator
+    # wired up so far.
+    if model_type == "lightgbm":
+        model, y_pred = train_lgb([X_train, X_test, y_train, y_test], categorical_features, model_params,
+                                  num_boost_round, sample_weight=sample_weight)
+    elif model_type == "SVR":
+        model, y_pred = train_svr([X_train, X_test, y_train, y_test], model_params, sample_weight=sample_weight)
+    else:
+        raise ValueError(f"Invalid model_type {model_type!r}, must be one of [lightgbm, SVR]")
+
+    # Evaluate on the held-out split.
+    mse = mean_squared_error(y_test, y_pred)
+    rmse = np.sqrt(mse)
+    mae = mean_absolute_error(y_test, y_pred)
+    print(f'The test rmse is: {rmse}, the test mae is: {mae}')
+    return model, features
+
+
+@app.route('/model_training_ml', methods=['POST'])
+def model_training_ml():
+    """Handle a training request and store the fitted model in MongoDB.
+
+    Reads the request parameters, loads the training data with
+    ``get_data_from_mongo``, trains via ``build_model`` and saves the model
+    together with its feature list through
+    ``insert_pickle_model_into_mongo``.
+
+    Returns:
+        dict with ``success`` (1 or 0), the echoed ``args``, ``start_time``,
+        ``end_time`` and, on failure, ``msg`` holding the traceback on a
+        single line.
+    """
+    # Record the start time of the request.
+    start_time = time.time()
+    result = {}
+    success = 0
+    # Bound before the try block so the response can always echo it, even
+    # when reading the request itself fails.
+    args = {}
+    print("Program starts execution!")
+    try:
+        args = request.values.to_dict()
+        print('args', args)
+        # Look the logger up by name: the module-level ``logger`` is only
+        # created when this file runs as a script, so relying on it raises
+        # NameError when the app is started by another WSGI launcher.
+        logging.getLogger("model_training_ml log").info(args)
+        power_df = get_data_from_mongo(args)
+        model, features = build_model(power_df, args)
+        insert_pickle_model_into_mongo(model, args, features=features)
+        success = 1
+    except Exception:
+        # str.replace returns a new string, so its result has to be kept.
+        result['msg'] = traceback.format_exc().replace("\n", "\t")
+    end_time = time.time()
+
+    result['success'] = success
+    result['args'] = args
+    result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
+    result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
+    print("Program execution ends!")
+    return result
+
+
+# Script entry point: configure logging and serve the Flask app with waitress.
+if __name__ == "__main__":
+    print("Program starts execution!")
+    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    logger = logging.getLogger("model_training_ml log")
+    from waitress import serve
+
+    # serve() blocks until the server stops, so announce start-up before it;
+    # printed afterwards, the message would only appear at shutdown.
+    print("server start!")
+    serve(app, host="0.0.0.0", port=10125)