ソースを参照

Merge branch 'refs/heads/dev_awg' into dev_hzh

# Conflicts:
#	common/database_dml.py
#	run_all.py
hzh 2 週間 前
コミット
e9710dac23
43 ファイル変更1631 行追加937 行削除
  1. 14 6
      Dockerfile
  2. 1 1
      common/data_cleaning.py
  3. 29 0
      common/data_utils.py
  4. 15 0
      common/database.toml
  5. 571 95
      common/database_dml.py
  6. 0 612
      common/database_dml_koi.py
  7. BIN
      common/jar/hive-jdbc-standalone.jar
  8. 234 0
      data_processing/data_operation/custom_data_handler.py
  9. 80 11
      data_processing/data_operation/data_nwp_ftp.py
  10. 54 0
      data_processing/data_operation/hive_to_mongo.py
  11. 6 54
      evaluation_processing/evaluation_accuracy.py
  12. 6 21
      models_processing/model_predict/model_prediction_lightgbm.py
  13. 96 0
      models_processing/model_predict/model_prediction_photovoltaic_physical.py
  14. 37 6
      models_processing/model_tf/losses.py
  15. 1 1
      models_processing/model_tf/lstm.yaml
  16. 1 1
      models_processing/model_tf/tf_bilstm.py
  17. 1 1
      models_processing/model_tf/tf_bilstm_2.py
  18. 1 1
      models_processing/model_tf/tf_bp.py
  19. 4 3
      models_processing/model_tf/tf_bp_pre.py
  20. 3 2
      models_processing/model_tf/tf_bp_train.py
  21. 1 1
      models_processing/model_tf/tf_cnn.py
  22. 4 3
      models_processing/model_tf/tf_cnn_pre.py
  23. 3 2
      models_processing/model_tf/tf_cnn_train.py
  24. 1 1
      models_processing/model_tf/tf_lstm.py
  25. 4 4
      models_processing/model_tf/tf_lstm2_pre.py
  26. 4 3
      models_processing/model_tf/tf_lstm2_train.py
  27. 4 4
      models_processing/model_tf/tf_lstm3_pre.py
  28. 3 2
      models_processing/model_tf/tf_lstm3_train.py
  29. 4 3
      models_processing/model_tf/tf_lstm_pre.py
  30. 3 2
      models_processing/model_tf/tf_lstm_train.py
  31. 1 1
      models_processing/model_tf/tf_lstm_zone.py
  32. 4 4
      models_processing/model_tf/tf_lstm_zone_pre.py
  33. 3 2
      models_processing/model_tf/tf_lstm_zone_train.py
  34. 143 0
      models_processing/model_tf/tf_multi_nwp_pre.py
  35. 123 0
      models_processing/model_tf/tf_multi_nwp_train.py
  36. 1 1
      models_processing/model_tf/tf_tcn.py
  37. 1 1
      models_processing/model_tf/tf_test.py
  38. 4 3
      models_processing/model_tf/tf_test_pre.py
  39. 3 2
      models_processing/model_tf/tf_test_train.py
  40. 1 1
      models_processing/model_tf/tf_transformer.py
  41. 156 75
      post_processing/cdq_coe_gen.py
  42. 4 1
      requirements.txt
  43. 2 6
      run_all.py

+ 14 - 6
Dockerfile

@@ -1,14 +1,22 @@
+# 使用官方 Python 镜像作为基础镜像
+FROM 192.168.1.36:5000/python:3.12
 
-FROM ubuntu:latest
-# 安装 tzdata 包并设置时区
-RUN apt-get update && apt-get install -y tzdata
 ENV TZ=Asia/Shanghai
-RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
+RUN apt-get update && \
+    apt-get install -y tzdata openjdk-17-jdk && \
+    ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && \
+    echo $TZ > /etc/timezone && \
+    DEBIAN_FRONTEND=noninteractive dpkg-reconfigure tzdata && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/*
+
 
-# 使用官方 Python 镜像作为基础镜像
-FROM 192.168.1.36:5000/python:3.12
 
 ENV LANG=en_US.UTF-8
+ENV JAVA_HOME=/usr/lib/jvm/java-17-openjdk-amd64
+ENV PATH="$JAVA_HOME/bin:$PATH"
+ENV CLASSPATH=.:$JAVA_HOME/lib:$JAVA_HOME/lib/tools.jar
+
 # 设置工作目录
 WORKDIR /app
 

+ 1 - 1
common/data_cleaning.py

@@ -31,7 +31,7 @@ def data_column_cleaning(data, logger, clean_value=[-99.0, -99]):
     for val in clean_value:
         data1 = data1.replace(val, np.nan)
     # nan 列超过80% 删除
-    data1 = data1.dropna(axis=1, thresh=len(data) * 0.8)
+    data1 = data1.dropna(axis=1, thresh=len(data) * 0.5)
     # 删除取值全部相同的列
     data1 = data1.loc[:, (data1 != data1.iloc[0]).any()]
     data = data[data1.columns.tolist()]

+ 29 - 0
common/data_utils.py

@@ -0,0 +1,29 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+# @FileName  :data_utils.py
+# @Time      :2025/5/21 16:16
+# @Author    :David
+# @Company: shenyang JY
+
+
+
+def deep_update(target, source):
+    """
+    递归将 source 字典的内容合并到 target 字典中
+    规则:
+      1. 若 key 在 target 和 source 中都存在,且值均为字典 → 递归合并
+      2. 若 key 在 source 中存在但 target 不存在 → 直接添加
+      3. 若 key 在 source 中存在且类型不为字典 → 覆盖 target 的值
+    """
+    for key, value in source.items():
+        # 如果 target 中存在该 key 且双方值都是字典 → 递归合并
+        if key in target and isinstance(target[key], dict) and isinstance(value, dict):
+            deep_update(target[key], value)
+        else:
+            # 直接覆盖或添加(包括非字典类型或 target 中不存在该 key 的情况)
+            target[key] = value
+    return target
+
+
+if __name__ == "__main__":
+    run_code = 0

+ 15 - 0
common/database.toml

@@ -0,0 +1,15 @@
+[mongodb]
+mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
+[hive]
+jdbc_url = "jdbc:hive2://basicserver1:2181,basicserver2:2181,basicserver3:2181,basicserver4:2181,basicserver5:2181/;serviceDiscoveryMode=zooKeeper;zooKeeperNamespace=hiveserver2"
+driver_class = "org.apache.hive.jdbc.HiveDriver"
+user = ""
+password = ""
+jar_file = 'jar/hive-jdbc-standalone.jar'
+
+[xmo]
+features = ['stationCode', 'date_time', 'forecastDatatime', 'rh',  'wd80', 'ws10', 'ws80', 'dniCalcd','rain', 'dewPoint2m', 'snowfall', 'windDirection10m', 'precipitation',
+           'apparentTemperature', 'weatherCode', 'sunshineDuration','shortwaveRadiation', 'directRadiation', 'diffuseRadiation','globalTiltedIrradiance', 'terrestrialRadiation']
+
+numeric_features = ['rh',  'wd80', 'ws10', 'ws80', 'dniCalcd','rain', 'dewPoint2m', 'snowfall', 'windDirection10m', 'precipitation',
+           'apparentTemperature', 'weatherCode', 'sunshineDuration','shortwaveRadiation', 'directRadiation', 'diffuseRadiation','globalTiltedIrradiance', 'terrestrialRadiation']

+ 571 - 95
common/database_dml.py

@@ -6,13 +6,26 @@ from sqlalchemy import create_engine
 import pickle
 from io import BytesIO
 import joblib
-import h5py
+import json
 import tensorflow as tf
 import os
 import tempfile
+import jaydebeapi
+import toml
+from typing import Dict, Any, Optional, Union, Tuple
+from datetime import datetime, timedelta
+
+# 读取 toml 配置文件
+current_dir = os.path.dirname(os.path.abspath(__file__))
+with open(os.path.join(current_dir, 'database.toml'), 'r', encoding='utf-8') as f:
+    config = toml.load(f)  # 只读的全局配置
+
+jar_file = os.path.join(current_dir, 'jar/hive-jdbc-standalone.jar')
+
 
 def get_data_from_mongo(args):
-    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
+    # 获取 hive 配置部分
+    mongodb_connection = config['mongodb']['mongodb_connection']
     mongodb_database = args['mongodb_database']
     mongodb_read_table = args['mongodb_read_table']
     query_dict = {}
@@ -42,7 +55,9 @@ def get_data_from_mongo(args):
 
 
 def get_df_list_from_mongo(args):
-    mongodb_connection,mongodb_database,mongodb_read_table = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['mongodb_read_table'].split(',')
+    # 获取 hive 配置部分
+    mongodb_connection = config['mongodb']['mongodb_connection']
+    mongodb_database, mongodb_read_table = args['mongodb_database'], args['mongodb_read_table'].split(',')
     df_list = []
     client = MongoClient(mongodb_connection)
     # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
@@ -58,6 +73,7 @@ def get_df_list_from_mongo(args):
     client.close()
     return df_list
 
+
 def insert_data_into_mongo(res_df, args):
     """
     插入数据到 MongoDB 集合中,可以选择覆盖、追加或按指定的 key 进行更新插入。
@@ -68,7 +84,8 @@ def insert_data_into_mongo(res_df, args):
     - overwrite: 布尔值,True 表示覆盖,False 表示追加
     - update_keys: 列表,指定用于匹配的 key 列,如果存在则更新,否则插入 'col1','col2'
     """
-    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
+    # 获取 hive 配置部分
+    mongodb_connection = config['mongodb']['mongodb_connection']
     mongodb_database = args['mongodb_database']
     mongodb_write_table = args['mongodb_write_table']
     overwrite = 1
@@ -119,7 +136,7 @@ def insert_data_into_mongo(res_df, args):
 def get_data_fromMysql(params):
     mysql_conn = params['mysql_conn']
     query_sql = params['query_sql']
-    #数据库读取实测气象
+    # 数据库读取实测气象
     engine = create_engine(f"mysql+pymysql://{mysql_conn}")
     # 定义SQL查询
     with engine.connect() as conn:
@@ -127,9 +144,11 @@ def get_data_fromMysql(params):
     return df
 
 
-def insert_pickle_model_into_mongo(model, args, features=None):
-    mongodb_connection, mongodb_database, mongodb_write_table, model_name = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/", \
-    args['mongodb_database'], args['mongodb_write_table'], args['model_name']
+def insert_pickle_model_into_mongo(model, args):
+    # 获取 hive 配置部分
+    mongodb_connection = config['mongodb']['mongodb_connection']
+    mongodb_database, mongodb_write_table, model_name = args['mongodb_database'], args['mongodb_write_table'], args[
+        'model_name']
     client = MongoClient(mongodb_connection)
     db = client[mongodb_database]
     # 序列化模型
@@ -138,9 +157,6 @@ def insert_pickle_model_into_mongo(model, args, features=None):
         'model_name': model_name,
         'model': model_bytes,  # 将模型字节流存入数据库
     }
-    # 保存模型特征
-    if features is not None:
-        model_data['features'] = features
     print('Training completed!')
 
     if mongodb_write_table in db.list_collection_names():
@@ -152,9 +168,27 @@ def insert_pickle_model_into_mongo(model, args, features=None):
     print("model inserted successfully!")
 
 
-def insert_h5_model_into_mongo(model,feature_scaler_bytes,target_scaler_bytes ,args):
-    mongodb_connection,mongodb_database,scaler_table,model_table,model_name = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
-                                args['mongodb_database'],args['scaler_table'],args['model_table'],args['model_name'])
+def get_pickle_model_from_mongo(args):
+    mongodb_connection = config['mongodb']['mongodb_connection']
+    mongodb_database, mongodb_model_table, model_name = args['mongodb_database'], args['mongodb_model_table'], args['model_name']
+    client = MongoClient(mongodb_connection)
+    db = client[mongodb_database]
+    collection = db[mongodb_model_table]
+    model_data = collection.find_one({"model_name": model_name})
+    if model_data is not None:
+        model_binary = model_data['model']  # 确保这个字段是存储模型的二进制数据
+        # 反序列化模型
+        model = pickle.loads(model_binary)
+        return model
+    else:
+        return None
+
+
+def insert_h5_model_into_mongo(model, feature_scaler_bytes, target_scaler_bytes, args):
+    # 获取 hive 配置部分
+    mongodb_connection = config['mongodb']['mongodb_connection']
+    mongodb_database, scaler_table, model_table, model_name = args['mongodb_database'], args['scaler_table'], args[
+        'model_table'], args['model_name']
     client = MongoClient(mongodb_connection)
     db = client[mongodb_database]
     if scaler_table in db.list_collection_names():
@@ -212,65 +246,341 @@ def insert_h5_model_into_mongo(model,feature_scaler_bytes,target_scaler_bytes ,a
                 print(f"⚠️ 临时文件清理失败: {temp_path}")
 
 
-# def insert_trained_model_into_mongo(model, args):
-#     mongodb_connection,mongodb_database,model_table,model_name = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
-#                                 args['mongodb_database'],args['model_table'],args['model_name'])
-#
-#     gen_time, params_json, descr = args['gen_time'], args['params'], args['descr']
-#     client = MongoClient(mongodb_connection)
-#     db = client[mongodb_database]
-#     if model_table in db.list_collection_names():
-#         db[model_table].drop()
-#         print(f"Collection '{model_table} already exist, deleted successfully!")
-#     model_table = db[model_table]
-#
-#     # 创建 BytesIO 缓冲区
-#     model_buffer = BytesIO()
-#     # 将模型保存为 HDF5 格式到内存 (BytesIO)
-#     model.save(model_buffer, save_format='h5')
-#     # 将指针移到缓冲区的起始位置
-#     model_buffer.seek(0)
-#     # 获取模型的二进制数据
-#     model_data = model_buffer.read()
-#     # 将模型保存到 MongoDB
-#     model_table.insert_one({
-#         "model_name": model_name,
-#         "model_data": model_data,
-#         "gen_time": gen_time,
-#         "params": params_json,
-#         "descr": descr
-#     })
-#     print("模型成功保存到 MongoDB!")
-
-def insert_scaler_model_into_mongo(feature_scaler_bytes, scaled_target_bytes, args):
-    mongodb_connection,mongodb_database,scaler_table,model_table,model_name = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
-                                args['mongodb_database'],args['scaler_table'],args['model_table'],args['model_name'])
-    client = MongoClient(mongodb_connection)
-    db = client[mongodb_database]
-    if scaler_table in db.list_collection_names():
-        db[scaler_table].drop()
-        print(f"Collection '{scaler_table} already exist, deleted successfully!")
-    collection = db[scaler_table]  # 集合名称
-    # Save the scalers in MongoDB as binary data
-    collection.insert_one({
-        "feature_scaler": feature_scaler_bytes.read(),
-        "target_scaler": scaled_target_bytes.read()
-    })
-    client.close()
-    print("scaler_model inserted successfully!")
+def insert_trained_model_into_mongo(model: tf.keras.Model, args: Dict[str, Any]) -> str:
+    """
+    将训练好的H5模型插入MongoDB,自动维护集合容量不超过50个模型
+    参数:
+    model : keras模型 - 训练好的Keras模型
+    args : dict - 包含以下键的字典:
+        mongodb_database: 数据库名称
+        model_table: 集合名称
+        model_name: 模型名称
+        gen_time: 模型生成时间(datetime对象)
+        params: 模型参数(JSON可序列化对象)
+        descr: 模型描述文本
+    """
+    # ------------------------- 参数校验 -------------------------
+    required_keys = {'mongodb_database', 'model_table', 'model_name',
+                     'gen_time', 'params', 'descr'}
+    if missing := required_keys - args.keys():
+        raise ValueError(f"缺少必要参数: {missing}")
 
+    # ------------------------- 配置解耦 -------------------------
+    # 从环境变量获取连接信息(更安全)
+    mongodb_connection = os.getenv("MONGO_URI", config['mongodb']['mongodb_connection'])
 
-def get_h5_model_from_mongo(args, custom=None):
-    mongodb_connection,mongodb_database,model_table,model_name = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['model_table'],args['model_name']
-    client = MongoClient(mongodb_connection)
-    # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
-    db = client[mongodb_database]
-    collection = db[model_table]  # 集合名称
+    # ------------------------- 资源初始化 -------------------------
+    fd, temp_path = None, None
+    client = None
+
+    try:
+        # ------------------------- 临时文件处理 -------------------------
+        fd, temp_path = tempfile.mkstemp(suffix='.keras')
+        os.close(fd)  # 立即释放文件锁
+
+        # ------------------------- 模型保存 -------------------------
+        try:
+            model.save(temp_path)  # 不指定save_format,默认使用keras新格式
+        except Exception as e:
+            raise RuntimeError(f"模型保存失败: {str(e)}") from e
+
+        # ------------------------- 数据库连接 -------------------------
+        client = MongoClient(mongodb_connection)
+        db = client[args['mongodb_database']]
+        collection = db[args['model_table']]
+
+        # ------------------------- 索引检查 -------------------------
+        # index_info = collection.index_information()
+        # if "gen_time_1" not in index_info:
+        #     print("开始创建索引...")
+        #     collection.create_index(
+        #         [("gen_time", ASCENDING)],
+        #         name="gen_time_1",
+        #         background=True
+        #     )
+        #     print("索引创建成功")
+        # else:
+        #     print("索引已存在,跳过创建")
+
+        # ------------------------- 容量控制 -------------------------
+        # 使用更高效的计数方式
+        if collection.estimated_document_count() >= 50:
+            # 原子性删除操作
+            if deleted := collection.find_one_and_delete(
+                    sort=[("gen_time", ASCENDING)],
+                    projection={"_id": 1, "model_name": 1, "gen_time": 1}
+            ):
+                print(f"已淘汰模型 [{deleted['model_name']}] 生成时间: {deleted['gen_time']}")
 
-     # 查询 MongoDB 获取模型数据
-    model_doc = collection.find_one({"model_name": model_name})
-    if model_doc:
-        model_data = model_doc['model_data']  # 获取模型的二进制数据
+        # ------------------------- 数据插入 -------------------------
+        with open(temp_path, 'rb') as f:
+            result = collection.insert_one({
+                "model_name": args['model_name'],
+                "model_data": f.read(),
+                "gen_time": args['gen_time'],
+                "params": args['params'],
+                "descr": args['descr']
+            })
+
+        print(f"✅ 模型 {args['model_name']} 保存成功 | 文档ID: {result.inserted_id}")
+        return str(result.inserted_id)
+
+    except Exception as e:
+        # ------------------------- 异常分类处理 -------------------------
+        error_type = "数据库操作" if isinstance(e, (pymongo.errors.PyMongoError, RuntimeError)) else "系统错误"
+        print(f"❌ {error_type} - 详细错误: {str(e)}")
+        raise  # 根据业务需求决定是否重新抛出
+
+    finally:
+        # ------------------------- 资源清理 -------------------------
+        if client:
+            client.close()
+        if temp_path and os.path.exists(temp_path):
+            try:
+                os.remove(temp_path)
+            except PermissionError:
+                print(f"⚠️ 临时文件清理失败: {temp_path}")
+
+
+def insert_scaler_model_into_mongo(feature_scaler_bytes: BytesIO, target_scaler_bytes: BytesIO,
+                                   args: Dict[str, Any]) -> str:
+    """
+    将特征缩放器存储到MongoDB,自动维护集合容量不超过50个文档
+
+    参数:
+    feature_scaler_bytes: BytesIO - 特征缩放器字节流
+    scaled_target_bytes: BytesIO - 目标缩放器字节流
+    args : dict - 包含以下键的字典:
+        mongodb_database: 数据库名称
+        scaler_table: 集合名称
+        model_name: 关联模型名称
+        gen_time: 生成时间(datetime对象)
+    """
+    # ------------------------- 参数校验 -------------------------
+    required_keys = {'mongodb_database', 'scaler_table', 'model_name', 'gen_time'}
+    if missing := required_keys - args.keys():
+        raise ValueError(f"缺少必要参数: {missing}")
+
+    # ------------------------- 配置解耦 -------------------------
+    # 从环境变量获取连接信息(安全隔离凭证)
+    mongodb_conn = os.getenv("MONGO_URI", config['mongodb']['mongodb_connection'])
+
+    # ------------------------- 输入验证 -------------------------
+    for buf, name in [(feature_scaler_bytes, "特征缩放器"),
+                      (target_scaler_bytes, "目标缩放器")]:
+        if not isinstance(buf, BytesIO):
+            raise TypeError(f"{name} 必须为BytesIO类型")
+        if buf.getbuffer().nbytes == 0:
+            raise ValueError(f"{name} 字节流为空")
+
+    client = None
+    try:
+        # ------------------------- 数据库连接 -------------------------
+        client = MongoClient(mongodb_conn)
+        db = client[args['mongodb_database']]
+        collection = db[args['scaler_table']]
+
+        # ------------------------- 索引维护 -------------------------
+        # if "gen_time_1" not in collection.index_information():
+        #     collection.create_index([("gen_time", ASCENDING)], name="gen_time_1")
+        #     print("⏱️ 已创建时间排序索引")
+
+        # ------------------------- 容量控制 -------------------------
+        # 使用近似计数提升性能(误差在几十条内可接受)
+        if collection.estimated_document_count() >= 50:
+            # 原子性删除操作(保证事务完整性)
+            if deleted := collection.find_one_and_delete(
+                    sort=[("gen_time", ASCENDING)],
+                    projection={"_id": 1, "model_name": 1, "gen_time": 1}
+            ):
+                print(f"🗑️ 已淘汰最旧缩放器 [{deleted['model_name']}] 生成时间: {deleted['gen_time']}")
+
+        # ------------------------- 数据插入 -------------------------
+        # 确保字节流指针位置正确
+        feature_scaler_bytes.seek(0)
+        target_scaler_bytes.seek(0)
+
+        result = collection.insert_one({
+            "model_name": args['model_name'],
+            "gen_time": args['gen_time'],
+            "feature_scaler": feature_scaler_bytes.read(),
+            "target_scaler": target_scaler_bytes.read()
+        })
+
+        print(f"✅ 缩放器 {args['model_name']} 保存成功 | 文档ID: {result.inserted_id}")
+        return str(result.inserted_id)
+
+    except Exception as e:
+        # ------------------------- 异常分类处理 -------------------------
+        error_type = "数据库操作" if isinstance(e, (pymongo.errors.PyMongoError, ValueError)) else "系统错误"
+        print(f"❌ {error_type}异常 - 详细错误: {str(e)}")
+        raise  # 根据业务需求决定是否重新抛出
+
+    finally:
+        # ------------------------- 资源清理 -------------------------
+        if client:
+            client.close()
+        # 重置字节流指针(确保后续可复用)
+        feature_scaler_bytes.seek(0)
+        target_scaler_bytes.seek(0)
+
+
+def get_h5_model_from_mongo(args: Dict[str, Any], custom_objects: Optional[Dict[str, Any]] = None) -> Optional[
+    tf.keras.Model]:
+    """
+    从MongoDB获取指定模型的最新版本
+
+    参数:
+    args : dict - 包含以下键的字典:
+        mongodb_database: 数据库名称
+        model_table: 集合名称
+        model_name: 要获取的模型名称
+    custom_objects: dict - 自定义Keras对象字典
+
+    返回:
+    tf.keras.Model - 加载成功的Keras模型
+    """
+    # ------------------------- 参数校验 -------------------------
+    required_keys = {'mongodb_database', 'model_table', 'model_name'}
+    if missing := required_keys - args.keys():
+        raise ValueError(f"❌ 缺失必要参数: {missing}")
+
+    # ------------------------- 环境配置 -------------------------
+    mongo_uri = os.getenv("MONGO_URI", config['mongodb']['mongodb_connection'])
+    client = None
+    tmp_file_path = None  # 用于跟踪临时文件路径
+    try:
+        # ------------------------- 数据库连接 -------------------------
+        client = MongoClient(
+            mongo_uri,
+            maxPoolSize=10,  # 连接池优化
+            socketTimeoutMS=5000
+        )
+        db = client[args['mongodb_database']]
+        collection = db[args['model_table']]
+
+        # ------------------------- 索引维护 -------------------------
+        index_name = "model_gen_time_idx"
+        if index_name not in collection.index_information():
+            collection.create_index(
+                [("model_name", 1), ("gen_time", DESCENDING)],
+                name=index_name
+            )
+            print("⏱️ 已创建复合索引")
+
+        # ------------------------- 高效查询 -------------------------
+        model_doc = collection.find_one(
+            {"model_name": args['model_name']},
+            sort=[('gen_time', DESCENDING)],
+            projection={"model_data": 1, "gen_time": 1}  # 获取必要字段
+        )
+
+        if not model_doc:
+            print(f"⚠️ 未找到模型 '{args['model_name']}' 的有效记录")
+            return None
+
+        # ------------------------- 内存优化加载 -------------------------
+        if model_doc:
+            model_data = model_doc['model_data']  # 获取模型的二进制数据
+            # # 将二进制数据加载到 BytesIO 缓冲区
+            # model_buffer = BytesIO(model_data)
+            # # 确保指针在起始位置
+            # model_buffer.seek(0)
+            # # 从缓冲区加载模型
+            # # 使用 h5py 和 BytesIO 从内存中加载模型
+            # with h5py.File(model_buffer, 'r', driver='fileobj') as f:
+            #     model = tf.keras.models.load_model(f, custom_objects=custom_objects)
+            # 创建临时文件
+            with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
+                tmp_file.write(model_data)
+                tmp_file_path = tmp_file.name  # 获取临时文件路径
+
+            # 从临时文件加载模型
+            model = tf.keras.models.load_model(tmp_file_path, custom_objects=custom_objects)
+
+            print(f"{args['model_name']}模型成功从 MongoDB 加载!")
+            return model
+    except tf.errors.NotFoundError as e:
+        print(f"❌ 模型结构缺失关键组件: {str(e)}")
+        raise RuntimeError("模型架构不完整") from e
+
+    except Exception as e:
+        print(f"❌ 系统异常: {str(e)}")
+        raise
+
+    finally:
+        # ------------------------- 资源清理 -------------------------
+        if client:
+            client.close()
+        # 确保删除临时文件
+        if tmp_file_path and os.path.exists(tmp_file_path):
+            try:
+                os.remove(tmp_file_path)
+                print(f"🧹 已清理临时文件: {tmp_file_path}")
+            except Exception as cleanup_err:
+                print(f"⚠️ 临时文件清理失败: {str(cleanup_err)}")
+
+
+def get_keras_model_from_mongo(
+        args: Dict[str, Any],
+        custom_objects: Optional[Dict[str, Any]] = None
+) -> Optional[tf.keras.Model]:
+    """
+    从MongoDB获取指定模型的最新版本(支持Keras格式)
+
+    参数:
+    args : dict - 包含以下键的字典:
+        mongodb_database: 数据库名称
+        model_table: 集合名称
+        model_name: 要获取的模型名称
+    custom_objects: dict - 自定义Keras对象字典
+
+    返回:
+    tf.keras.Model - 加载成功的Keras模型
+    """
+    # ------------------------- 参数校验 -------------------------
+    required_keys = {'mongodb_database', 'model_table', 'model_name'}
+    if missing := required_keys - args.keys():
+        raise ValueError(f"❌ 缺失必要参数: {missing}")
+
+    # ------------------------- 环境配置 -------------------------
+    mongo_uri = os.getenv("MONGO_URI", config['mongodb']['mongodb_connection'])
+    client = None
+    tmp_file_path = None  # 用于跟踪临时文件路径
+
+    try:
+        # ------------------------- 数据库连接 -------------------------
+        client = MongoClient(
+            mongo_uri,
+            maxPoolSize=10,
+            socketTimeoutMS=5000
+        )
+        db = client[args['mongodb_database']]
+        collection = db[args['model_table']]
+
+        # ------------------------- 索引维护 -------------------------
+        # index_name = "model_gen_time_idx"
+        # if index_name not in collection.index_information():
+        #     collection.create_index(
+        #         [("model_name", 1), ("gen_time", DESCENDING)],
+        #         name=index_name
+        #     )
+        #     print("⏱️ 已创建复合索引")
+
+        # ------------------------- 高效查询 -------------------------
+        model_doc = collection.find_one(
+            {"model_name": args['model_name']},
+            sort=[('gen_time', DESCENDING)],
+            projection={"model_data": 1, "gen_time": 1, 'params': 1}
+        )
+
+        if not model_doc:
+            print(f"⚠️ 未找到模型 '{args['model_name']}' 的有效记录")
+            return None
+
+        # ------------------------- 内存优化加载 -------------------------
+        model_data = model_doc['model_data']
+        model_params = model_doc['params']
         # 创建临时文件(自动删除)
         with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
             tmp_file.write(model_data)
@@ -279,11 +589,25 @@ def get_h5_model_from_mongo(args, custom=None):
         # 从临时文件加载模型
         model = tf.keras.models.load_model(
             tmp_file_path,
-            custom_objects=custom
+            custom_objects=custom_objects
         )
 
         print(f"{args['model_name']} 模型成功从 MongoDB 加载!")
-        client.close()
+        return model, model_params
+
+    except tf.errors.NotFoundError as e:
+        print(f"❌ 模型结构缺失关键组件: {str(e)}")
+        raise RuntimeError("模型架构不完整") from e
+
+    except Exception as e:
+        print(f"❌ 系统异常: {str(e)}")
+        raise
+
+    finally:
+        # ------------------------- 资源清理 -------------------------
+        if client:
+            client.close()
+
         # 确保删除临时文件
         if tmp_file_path and os.path.exists(tmp_file_path):
             try:
@@ -291,28 +615,180 @@ def get_h5_model_from_mongo(args, custom=None):
                 print(f"🧹 已清理临时文件: {tmp_file_path}")
             except Exception as cleanup_err:
                 print(f"⚠️ 临时文件清理失败: {str(cleanup_err)}")
-        return model
-    else:
-        print(f"未找到model_name为 {model_name} 的模型。")
-        client.close()
-        return None
 
 
-def get_scaler_model_from_mongo(args, only_feature_scaler=False):
-    mongodb_connection, mongodb_database, scaler_table, = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/", args['mongodb_database'], args['scaler_table'])
-    client = MongoClient(mongodb_connection)
-    # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
-    db = client[mongodb_database]
-    collection = db[scaler_table]  # 集合名称
-    # Retrieve the scalers from MongoDB
-    scaler_doc = collection.find_one()
-    # Deserialize the scalers
-
-    feature_scaler_bytes = BytesIO(scaler_doc["feature_scaler"])
-    feature_scaler = joblib.load(feature_scaler_bytes)
-    if only_feature_scaler:
-        return feature_scaler
-    target_scaler_bytes = BytesIO(scaler_doc["target_scaler"])
-    target_scaler = joblib.load(target_scaler_bytes)
-    client.close()
-    return feature_scaler,target_scaler
+def get_scaler_model_from_mongo(args: Dict[str, Any], only_feature_scaler: bool = False) -> Union[
+    object, Tuple[object, object]]:
+    """
+    优化版特征缩放器加载函数 - 安全高效获取最新预处理模型
+
+    参数:
+    args : 必须包含键:
+        - mongodb_database: 数据库名称
+        - scaler_table: 集合名称
+        - model_name: 目标模型名称
+    only_feature_scaler : 是否仅返回特征缩放器
+
+    返回:
+    单个缩放器对象或(feature_scaler, target_scaler)元组
+
+    异常:
+    ValueError : 参数缺失或类型错误
+    RuntimeError : 数据操作异常
+    """
+    # ------------------------- 参数校验 -------------------------
+    required_keys = {'mongodb_database', 'scaler_table', 'model_name'}
+    if missing := required_keys - args.keys():
+        raise ValueError(f"❌ 缺失必要参数: {missing}")
+
+    # ------------------------- 环境配置 -------------------------
+    mongo_uri = os.getenv("MONGO_URI", config['mongodb']['mongodb_connection'])
+
+    client = None
+    try:
+        # ------------------------- 数据库连接 -------------------------
+        client = MongoClient(
+            mongo_uri,
+            maxPoolSize=20,  # 连接池上限
+            socketTimeoutMS=3000,  # 3秒超时
+            serverSelectionTimeoutMS=5000  # 5秒服务器选择超时
+        )
+        db = client[args['mongodb_database']]
+        collection = db[args['scaler_table']]
+
+        # ------------------------- 索引维护 -------------------------
+        # index_name = "model_gen_time_idx"
+        # if index_name not in collection.index_information():
+        #     collection.create_index(
+        #         [("model_name", 1), ("gen_time", DESCENDING)],
+        #         name=index_name,
+        #         background=True  # 后台构建避免阻塞
+        #     )
+        #     print("⏱️ 已创建特征缩放器复合索引")
+
+        # ------------------------- 高效查询 -------------------------
+        scaler_doc = collection.find_one(
+            {"model_name": args['model_name']},
+            sort=[('gen_time', DESCENDING)],
+            projection={"feature_scaler": 1, "target_scaler": 1, "gen_time": 1}
+        )
+
+        if not scaler_doc:
+            raise RuntimeError(f"⚠️ 找不到模型 {args['model_name']} 的缩放器记录")
+
+        # ------------------------- 反序列化处理 -------------------------
+        def load_scaler(data: bytes) -> object:
+            """安全加载序列化对象"""
+            with BytesIO(data) as buffer:
+                buffer.seek(0)  # 确保指针复位
+                try:
+                    return joblib.load(buffer)
+                except joblib.UnpicklingError as e:
+                    raise RuntimeError("反序列化失败 (可能版本不兼容)") from e
+
+        # 特征缩放器加载
+        feature_data = scaler_doc["feature_scaler"]
+        if not isinstance(feature_data, bytes):
+            raise RuntimeError("特征缩放器数据格式异常")
+        feature_scaler = load_scaler(feature_data)
+
+        if only_feature_scaler:
+            return feature_scaler
+
+        # 目标缩放器加载
+        target_data = scaler_doc["target_scaler"]
+        if not isinstance(target_data, bytes):
+            raise RuntimeError("目标缩放器数据格式异常")
+        target_scaler = load_scaler(target_data)
+
+        print(f"✅ 成功加载 {args['model_name']} 的缩放器 (版本时间: {scaler_doc.get('gen_time', '未知')})")
+        return feature_scaler, target_scaler
+
+    except PyMongoError as e:
+        raise RuntimeError(f"🔌 数据库操作失败: {str(e)}") from e
+    except RuntimeError as e:
+        raise RuntimeError(f"🔌 mongo操作失败: {str(e)}") from e  # 直接传递已封装的异常
+    except Exception as e:
+        raise RuntimeError(f"❌ 未知系统异常: {str(e)}") from e
+    finally:
+        # ------------------------- 资源清理 -------------------------
+        if client:
+            client.close()
+
+
+def normalize_key(s):
+    return s.lower()
+
+
+def get_xmo_data_from_hive(args):
+    # 获取 hive 配置部分
+    hive_config = config['hive']
+    jdbc_url = hive_config['jdbc_url']
+    driver_class = hive_config['driver_class']
+    user = hive_config['user']
+    password = hive_config['password']
+    features = config['xmo']['features']
+    numeric_features = config['xmo']['numeric_features']
+    if 'moment' not in args or 'farm_id' not in args:
+        msg_error = 'One or more of the following parameters are missing: moment, farm_id!'
+        return msg_error
+    else:
+        moment = args['moment']
+        farm_id = args['farm_id']
+
+        if 'current_date' in args:
+            current_date = datetime.strptime(args['current_date'], "%Y%m%d")
+        else:
+            current_date = datetime.now()
+        if 'days' in args:
+            days = int(args['days']) + 1
+        else:
+            days = 1
+        json_feature = f"nwp_xmo_{moment}"
+        # 建立连接
+        conn = jaydebeapi.connect(driver_class, jdbc_url, [user, password], jar_file)
+        # 查询 Hive 表
+        cursor = conn.cursor()
+        query_sql = ""
+        for i in range(0, days):
+            sysdate_pre = (current_date + timedelta(days=i)).strftime("%Y%m%d")
+            if i == 0:
+                pass
+            else:
+                query_sql += "union \n"
+
+            query_sql += """select rowkey,datatimestamp,{2} from hbase_forecast.forecast_xmo_d{3} 
+                                                     where rowkey>='{0}-{1}0000' and rowkey<='{0}-{1}2345' \n""".format(
+                farm_id, sysdate_pre, json_feature, i)
+        print("query_sql\n", query_sql)
+        cursor.execute(query_sql)
+        # 获取列名
+        columns = [desc[0] for desc in cursor.description]
+        # 获取所有数据
+        rows = cursor.fetchall()
+        # 转成 DataFrame
+        df = pd.DataFrame(rows, columns=columns)
+        cursor.close()
+        conn.close()
+        df[json_feature] = df[json_feature].apply(lambda x: json.loads(x) if isinstance(x, str) else x)
+        df_features = pd.json_normalize(df[json_feature])
+        if 'forecastDatatime' not in df_features.columns:
+            return "The returned data does not contain the forecastDatetime column — the data might be empty or null!"
+        else:
+            df_features['date_time'] = pd.to_datetime(df_features['forecastDatatime'], unit='ms', utc=True).dt.tz_convert(
+                'Asia/Shanghai').dt.strftime('%Y-%m-%d %H:%M:%S')
+            df_features[numeric_features] = df_features[numeric_features].apply(pd.to_numeric, errors='coerce')
+            return df_features[features]
+
+
+if __name__ == "__main__":
+    print("Program starts execution!")
+    args = {
+        'moment': '06',
+        'current_date': '20250609',
+        'farm_id': 'J00883',
+        'days': '13'
+    }
+    df = get_xmo_data_from_hive(args)
+    print(df.head(2),df.shape)
+    print("server start!")

+ 0 - 612
common/database_dml_koi.py

@@ -1,612 +0,0 @@
-import pymongo
-from pymongo import MongoClient, UpdateOne, DESCENDING, ASCENDING
-from pymongo.errors import PyMongoError
-import pandas as pd
-from sqlalchemy import create_engine
-import pickle
-from io import BytesIO
-import joblib
-import h5py, os, io
-import tensorflow as tf
-from typing import Dict, Any, Optional, Union, Tuple
-import tempfile
-
-def get_data_from_mongo(args):
-    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
-    mongodb_database = args['mongodb_database']
-    mongodb_read_table = args['mongodb_read_table']
-    query_dict = {}
-    if 'timeBegin' in args.keys():
-        timeBegin = args['timeBegin']
-        query_dict.update({"$gte": timeBegin})
-    if 'timeEnd' in args.keys():
-        timeEnd = args['timeEnd']
-        query_dict.update({"$lte": timeEnd})
-
-    client = MongoClient(mongodb_connection)
-    # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
-    db = client[mongodb_database]
-    collection = db[mongodb_read_table]  # 集合名称
-    if len(query_dict) != 0:
-        query = {"dateTime": query_dict}
-        cursor = collection.find(query)
-    else:
-        cursor = collection.find()
-    data = list(cursor)
-    df = pd.DataFrame(data)
-    # 4. 删除 _id 字段(可选)
-    if '_id' in df.columns:
-        df = df.drop(columns=['_id'])
-    client.close()
-    return df
-
-
-def get_df_list_from_mongo(args):
-    mongodb_connection,mongodb_database,mongodb_read_table = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['mongodb_read_table'].split(',')
-    df_list = []
-    client = MongoClient(mongodb_connection)
-    # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
-    db = client[mongodb_database]
-    for table in mongodb_read_table:
-        collection = db[table]  # 集合名称
-        data_from_db = collection.find()  # 这会返回一个游标(cursor)
-        # 将游标转换为列表,并创建 pandas DataFrame
-        df = pd.DataFrame(list(data_from_db))
-        if '_id' in df.columns:
-            df = df.drop(columns=['_id'])
-        df_list.append(df)
-    client.close()
-    return df_list
-
-def insert_data_into_mongo(res_df, args):
-    """
-    插入数据到 MongoDB 集合中,可以选择覆盖、追加或按指定的 key 进行更新插入。
-
-    参数:
-    - res_df: 要插入的 DataFrame 数据
-    - args: 包含 MongoDB 数据库和集合名称的字典
-    - overwrite: 布尔值,True 表示覆盖,False 表示追加
-    - update_keys: 列表,指定用于匹配的 key 列,如果存在则更新,否则插入 'col1','col2'
-    """
-    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
-    mongodb_database = args['mongodb_database']
-    mongodb_write_table = args['mongodb_write_table']
-    overwrite = 1
-    update_keys = None
-    if 'overwrite' in args.keys():
-        overwrite = int(args['overwrite'])
-    if 'update_keys' in args.keys():
-        update_keys = args['update_keys'].split(',')
-
-    client = MongoClient(mongodb_connection)
-    db = client[mongodb_database]
-    collection = db[mongodb_write_table]
-
-    # 覆盖模式:删除现有集合
-    if overwrite:
-        if mongodb_write_table in db.list_collection_names():
-            collection.drop()
-            print(f"Collection '{mongodb_write_table}' already exists, deleted successfully!")
-
-    # 将 DataFrame 转为字典格式
-    data_dict = res_df.to_dict("records")  # 每一行作为一个字典
-
-    # 如果没有数据,直接返回
-    if not data_dict:
-        print("No data to insert.")
-        return
-
-    # 如果指定了 update_keys,则执行 upsert(更新或插入)
-    if update_keys and not overwrite:
-        operations = []
-        for record in data_dict:
-            # 构建查询条件,用于匹配要更新的文档
-            query = {key: record[key] for key in update_keys}
-            operations.append(UpdateOne(query, {'$set': record}, upsert=True))
-
-        # 批量执行更新/插入操作
-        if operations:
-            result = collection.bulk_write(operations)
-            print(f"Matched: {result.matched_count}, Upserts: {result.upserted_count}")
-    else:
-        # 追加模式:直接插入新数据
-        collection.insert_many(data_dict)
-        print("Data inserted successfully!")
-
-
-def get_data_fromMysql(params):
-    mysql_conn = params['mysql_conn']
-    query_sql = params['query_sql']
-    #数据库读取实测气象
-    engine = create_engine(f"mysql+pymysql://{mysql_conn}")
-    # 定义SQL查询
-    env_df = pd.read_sql_query(query_sql, engine)
-    return env_df
-
-
-def insert_pickle_model_into_mongo(model, args):
-    mongodb_connection, mongodb_database, mongodb_write_table, model_name = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/", \
-    args['mongodb_database'], args['mongodb_write_table'], args['model_name']
-    client = MongoClient(mongodb_connection)
-    db = client[mongodb_database]
-    # 序列化模型
-    model_bytes = pickle.dumps(model)
-    model_data = {
-        'model_name': model_name,
-        'model': model_bytes,  # 将模型字节流存入数据库
-    }
-    print('Training completed!')
-
-    if mongodb_write_table in db.list_collection_names():
-        db[mongodb_write_table].drop()
-        print(f"Collection '{mongodb_write_table} already exist, deleted successfully!")
-    collection = db[mongodb_write_table]  # 集合名称
-    collection.insert_one(model_data)
-    print("model inserted successfully!")
-
-
-def insert_trained_model_into_mongo(model: tf.keras.Model, args: Dict[str, Any]) -> str:
-    """
-    将训练好的H5模型插入MongoDB,自动维护集合容量不超过50个模型
-    参数:
-    model : keras模型 - 训练好的Keras模型
-    args : dict - 包含以下键的字典:
-        mongodb_database: 数据库名称
-        model_table: 集合名称
-        model_name: 模型名称
-        gen_time: 模型生成时间(datetime对象)
-        params: 模型参数(JSON可序列化对象)
-        descr: 模型描述文本
-    """
-    # ------------------------- 参数校验 -------------------------
-    required_keys = {'mongodb_database', 'model_table', 'model_name',
-                     'gen_time', 'params', 'descr'}
-    if missing := required_keys - args.keys():
-        raise ValueError(f"缺少必要参数: {missing}")
-
-    # ------------------------- 配置解耦 -------------------------
-    # 从环境变量获取连接信息(更安全)
-    mongodb_connection = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
-
-    # ------------------------- 资源初始化 -------------------------
-    fd, temp_path = None, None
-    client = None
-
-    try:
-        # ------------------------- 临时文件处理 -------------------------
-        fd, temp_path = tempfile.mkstemp(suffix='.keras')
-        os.close(fd)  # 立即释放文件锁
-
-        # ------------------------- 模型保存 -------------------------
-        try:
-            model.save(temp_path) # 不指定save_format,默认使用keras新格式
-        except Exception as e:
-            raise RuntimeError(f"模型保存失败: {str(e)}") from e
-
-        # ------------------------- 数据库连接 -------------------------
-        client = MongoClient(mongodb_connection)
-        db = client[args['mongodb_database']]
-        collection = db[args['model_table']]
-
-        # ------------------------- 索引检查 -------------------------
-        # index_info = collection.index_information()
-        # if "gen_time_1" not in index_info:
-        #     print("开始创建索引...")
-        #     collection.create_index(
-        #         [("gen_time", ASCENDING)],
-        #         name="gen_time_1",
-        #         background=True
-        #     )
-        #     print("索引创建成功")
-        # else:
-        #     print("索引已存在,跳过创建")
-
-        # ------------------------- 容量控制 -------------------------
-        # 使用更高效的计数方式
-        if collection.estimated_document_count() >= 50:
-            # 原子性删除操作
-            if deleted := collection.find_one_and_delete(
-                    sort=[("gen_time", ASCENDING)],
-                    projection={"_id": 1, "model_name": 1, "gen_time": 1}
-            ):
-                print(f"已淘汰模型 [{deleted['model_name']}] 生成时间: {deleted['gen_time']}")
-
-        # ------------------------- 数据插入 -------------------------
-        with open(temp_path, 'rb') as f:
-            result = collection.insert_one({
-                "model_name": args['model_name'],
-                "model_data": f.read(),
-                "gen_time": args['gen_time'],
-                "params": args['params'],
-                "descr": args['descr']
-            })
-
-        print(f"✅ 模型 {args['model_name']} 保存成功 | 文档ID: {result.inserted_id}")
-        return str(result.inserted_id)
-
-    except Exception as e:
-        # ------------------------- 异常分类处理 -------------------------
-        error_type = "数据库操作" if isinstance(e, (pymongo.errors.PyMongoError, RuntimeError)) else "系统错误"
-        print(f"❌ {error_type} - 详细错误: {str(e)}")
-        raise  # 根据业务需求决定是否重新抛出
-
-    finally:
-        # ------------------------- 资源清理 -------------------------
-        if client:
-            client.close()
-        if temp_path and os.path.exists(temp_path):
-            try:
-                os.remove(temp_path)
-            except PermissionError:
-                print(f"⚠️ 临时文件清理失败: {temp_path}")
-
-
-def insert_scaler_model_into_mongo(feature_scaler_bytes: BytesIO, target_scaler_bytes: BytesIO, args: Dict[str, Any]) -> str:
-    """
-    将特征缩放器存储到MongoDB,自动维护集合容量不超过50个文档
-
-    参数:
-    feature_scaler_bytes: BytesIO - 特征缩放器字节流
-    scaled_target_bytes: BytesIO - 目标缩放器字节流
-    args : dict - 包含以下键的字典:
-        mongodb_database: 数据库名称
-        scaler_table: 集合名称
-        model_name: 关联模型名称
-        gen_time: 生成时间(datetime对象)
-    """
-    # ------------------------- 参数校验 -------------------------
-    required_keys = {'mongodb_database', 'scaler_table', 'model_name', 'gen_time'}
-    if missing := required_keys - args.keys():
-        raise ValueError(f"缺少必要参数: {missing}")
-
-    # ------------------------- 配置解耦 -------------------------
-    # 从环境变量获取连接信息(安全隔离凭证)
-    mongodb_conn = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
-
-    # ------------------------- 输入验证 -------------------------
-    for buf, name in [(feature_scaler_bytes, "特征缩放器"),
-                      (target_scaler_bytes, "目标缩放器")]:
-        if not isinstance(buf, BytesIO):
-            raise TypeError(f"{name} 必须为BytesIO类型")
-        if buf.getbuffer().nbytes == 0:
-            raise ValueError(f"{name} 字节流为空")
-
-    client = None
-    try:
-        # ------------------------- 数据库连接 -------------------------
-        client = MongoClient(mongodb_conn)
-        db = client[args['mongodb_database']]
-        collection = db[args['scaler_table']]
-
-        # ------------------------- 索引维护 -------------------------
-        # if "gen_time_1" not in collection.index_information():
-        #     collection.create_index([("gen_time", ASCENDING)], name="gen_time_1")
-        #     print("⏱️ 已创建时间排序索引")
-
-        # ------------------------- 容量控制 -------------------------
-        # 使用近似计数提升性能(误差在几十条内可接受)
-        if collection.estimated_document_count() >= 50:
-            # 原子性删除操作(保证事务完整性)
-            if deleted := collection.find_one_and_delete(
-                    sort=[("gen_time", ASCENDING)],
-                    projection={"_id": 1, "model_name": 1, "gen_time": 1}
-            ):
-                print(f"🗑️ 已淘汰最旧缩放器 [{deleted['model_name']}] 生成时间: {deleted['gen_time']}")
-
-        # ------------------------- 数据插入 -------------------------
-        # 确保字节流指针位置正确
-        feature_scaler_bytes.seek(0)
-        target_scaler_bytes.seek(0)
-
-        result = collection.insert_one({
-            "model_name": args['model_name'],
-            "gen_time": args['gen_time'],
-            "feature_scaler": feature_scaler_bytes.read(),
-            "target_scaler": target_scaler_bytes.read()
-        })
-
-        print(f"✅ 缩放器 {args['model_name']} 保存成功 | 文档ID: {result.inserted_id}")
-        return str(result.inserted_id)
-
-    except Exception as e:
-        # ------------------------- 异常分类处理 -------------------------
-        error_type = "数据库操作" if isinstance(e, (pymongo.errors.PyMongoError, ValueError)) else "系统错误"
-        print(f"❌ {error_type}异常 - 详细错误: {str(e)}")
-        raise  # 根据业务需求决定是否重新抛出
-
-    finally:
-        # ------------------------- 资源清理 -------------------------
-        if client:
-            client.close()
-        # 重置字节流指针(确保后续可复用)
-        feature_scaler_bytes.seek(0)
-        target_scaler_bytes.seek(0)
-
-
-def get_h5_model_from_mongo( args: Dict[str, Any], custom_objects: Optional[Dict[str, Any]] = None) -> Optional[tf.keras.Model]:
-    """
-    从MongoDB获取指定模型的最新版本
-
-    参数:
-    args : dict - 包含以下键的字典:
-        mongodb_database: 数据库名称
-        model_table: 集合名称
-        model_name: 要获取的模型名称
-    custom_objects: dict - 自定义Keras对象字典
-
-    返回:
-    tf.keras.Model - 加载成功的Keras模型
-    """
-    # ------------------------- 参数校验 -------------------------
-    required_keys = {'mongodb_database', 'model_table', 'model_name'}
-    if missing := required_keys - args.keys():
-        raise ValueError(f"❌ 缺失必要参数: {missing}")
-
-    # ------------------------- 环境配置 -------------------------
-    mongo_uri = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
-    client = None
-    tmp_file_path = None  # 用于跟踪临时文件路径
-    try:
-        # ------------------------- 数据库连接 -------------------------
-        client = MongoClient(
-            mongo_uri,
-            maxPoolSize=10,  # 连接池优化
-            socketTimeoutMS=5000
-        )
-        db = client[args['mongodb_database']]
-        collection = db[args['model_table']]
-
-        # ------------------------- 索引维护 -------------------------
-        index_name = "model_gen_time_idx"
-        if index_name not in collection.index_information():
-            collection.create_index(
-                [("model_name", 1), ("gen_time", DESCENDING)],
-                name=index_name
-            )
-            print("⏱️ 已创建复合索引")
-
-        # ------------------------- 高效查询 -------------------------
-        model_doc = collection.find_one(
-            {"model_name": args['model_name']},
-            sort=[('gen_time', DESCENDING)],
-            projection={"model_data": 1, "gen_time": 1}  # 获取必要字段
-        )
-
-        if not model_doc:
-            print(f"⚠️ 未找到模型 '{args['model_name']}' 的有效记录")
-            return None
-
-        # ------------------------- 内存优化加载 -------------------------
-        if model_doc:
-            model_data = model_doc['model_data']  # 获取模型的二进制数据
-            # # 将二进制数据加载到 BytesIO 缓冲区
-            # model_buffer = BytesIO(model_data)
-            # # 确保指针在起始位置
-            # model_buffer.seek(0)
-            # # 从缓冲区加载模型
-            # # 使用 h5py 和 BytesIO 从内存中加载模型
-            # with h5py.File(model_buffer, 'r', driver='fileobj') as f:
-            #     model = tf.keras.models.load_model(f, custom_objects=custom_objects)
-            # 创建临时文件
-            with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
-                tmp_file.write(model_data)
-                tmp_file_path = tmp_file.name  # 获取临时文件路径
-
-            # 从临时文件加载模型
-            model = tf.keras.models.load_model(tmp_file_path, custom_objects=custom_objects)
-
-            print(f"{args['model_name']}模型成功从 MongoDB 加载!")
-            return model
-    except tf.errors.NotFoundError as e:
-        print(f"❌ 模型结构缺失关键组件: {str(e)}")
-        raise RuntimeError("模型架构不完整") from e
-
-    except Exception as e:
-        print(f"❌ 系统异常: {str(e)}")
-        raise
-
-    finally:
-        # ------------------------- 资源清理 -------------------------
-        if client:
-            client.close()
-        # 确保删除临时文件
-        if tmp_file_path and os.path.exists(tmp_file_path):
-            try:
-                os.remove(tmp_file_path)
-                print(f"🧹 已清理临时文件: {tmp_file_path}")
-            except Exception as cleanup_err:
-                print(f"⚠️ 临时文件清理失败: {str(cleanup_err)}")
-
-
-def get_keras_model_from_mongo(
-        args: Dict[str, Any],
-        custom_objects: Optional[Dict[str, Any]] = None
-) -> Optional[tf.keras.Model]:
-    """
-    从MongoDB获取指定模型的最新版本(支持Keras格式)
-
-    参数:
-    args : dict - 包含以下键的字典:
-        mongodb_database: 数据库名称
-        model_table: 集合名称
-        model_name: 要获取的模型名称
-    custom_objects: dict - 自定义Keras对象字典
-
-    返回:
-    tf.keras.Model - 加载成功的Keras模型
-    """
-    # ------------------------- 参数校验 -------------------------
-    required_keys = {'mongodb_database', 'model_table', 'model_name'}
-    if missing := required_keys - args.keys():
-        raise ValueError(f"❌ 缺失必要参数: {missing}")
-
-    # ------------------------- 环境配置 -------------------------
-    mongo_uri = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
-    client = None
-    tmp_file_path = None  # 用于跟踪临时文件路径
-
-    try:
-        # ------------------------- 数据库连接 -------------------------
-        client = MongoClient(
-            mongo_uri,
-            maxPoolSize=10,
-            socketTimeoutMS=5000
-        )
-        db = client[args['mongodb_database']]
-        collection = db[args['model_table']]
-
-        # ------------------------- 索引维护 -------------------------
-        # index_name = "model_gen_time_idx"
-        # if index_name not in collection.index_information():
-        #     collection.create_index(
-        #         [("model_name", 1), ("gen_time", DESCENDING)],
-        #         name=index_name
-        #     )
-        #     print("⏱️ 已创建复合索引")
-
-        # ------------------------- 高效查询 -------------------------
-        model_doc = collection.find_one(
-            {"model_name": args['model_name']},
-            sort=[('gen_time', DESCENDING)],
-            projection={"model_data": 1, "gen_time": 1, 'params':1}
-        )
-
-        if not model_doc:
-            print(f"⚠️ 未找到模型 '{args['model_name']}' 的有效记录")
-            return None
-
-        # ------------------------- 内存优化加载 -------------------------
-        model_data = model_doc['model_data']
-        model_params = model_doc['params']
-        # 创建临时文件(自动删除)
-        with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
-            tmp_file.write(model_data)
-            tmp_file_path = tmp_file.name  # 记录文件路径
-
-        # 从临时文件加载模型
-        model = tf.keras.models.load_model(
-            tmp_file_path,
-            custom_objects=custom_objects
-        )
-
-        print(f"{args['model_name']} 模型成功从 MongoDB 加载!")
-        return model, model_params
-
-    except tf.errors.NotFoundError as e:
-        print(f"❌ 模型结构缺失关键组件: {str(e)}")
-        raise RuntimeError("模型架构不完整") from e
-
-    except Exception as e:
-        print(f"❌ 系统异常: {str(e)}")
-        raise
-
-    finally:
-        # ------------------------- 资源清理 -------------------------
-        if client:
-            client.close()
-
-        # 确保删除临时文件
-        if tmp_file_path and os.path.exists(tmp_file_path):
-            try:
-                os.remove(tmp_file_path)
-                print(f"🧹 已清理临时文件: {tmp_file_path}")
-            except Exception as cleanup_err:
-                print(f"⚠️ 临时文件清理失败: {str(cleanup_err)}")
-
-def get_scaler_model_from_mongo(args: Dict[str, Any], only_feature_scaler: bool = False) -> Union[object, Tuple[object, object]]:
-    """
-    优化版特征缩放器加载函数 - 安全高效获取最新预处理模型
-
-    参数:
-    args : 必须包含键:
-        - mongodb_database: 数据库名称
-        - scaler_table: 集合名称
-        - model_name: 目标模型名称
-    only_feature_scaler : 是否仅返回特征缩放器
-
-    返回:
-    单个缩放器对象或(feature_scaler, target_scaler)元组
-
-    异常:
-    ValueError : 参数缺失或类型错误
-    RuntimeError : 数据操作异常
-    """
-    # ------------------------- 参数校验 -------------------------
-    required_keys = {'mongodb_database', 'scaler_table', 'model_name'}
-    if missing := required_keys - args.keys():
-        raise ValueError(f"❌ 缺失必要参数: {missing}")
-
-    # ------------------------- 环境配置 -------------------------
-    mongo_uri = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
-
-    client = None
-    try:
-        # ------------------------- 数据库连接 -------------------------
-        client = MongoClient(
-            mongo_uri,
-            maxPoolSize=20,  # 连接池上限
-            socketTimeoutMS=3000,  # 3秒超时
-            serverSelectionTimeoutMS=5000  # 5秒服务器选择超时
-        )
-        db = client[args['mongodb_database']]
-        collection = db[args['scaler_table']]
-
-        # ------------------------- 索引维护 -------------------------
-        # index_name = "model_gen_time_idx"
-        # if index_name not in collection.index_information():
-        #     collection.create_index(
-        #         [("model_name", 1), ("gen_time", DESCENDING)],
-        #         name=index_name,
-        #         background=True  # 后台构建避免阻塞
-        #     )
-        #     print("⏱️ 已创建特征缩放器复合索引")
-
-        # ------------------------- 高效查询 -------------------------
-        scaler_doc = collection.find_one(
-            {"model_name": args['model_name']},
-            sort=[('gen_time', DESCENDING)],
-            projection={"feature_scaler": 1, "target_scaler": 1, "gen_time": 1}
-        )
-
-        if not scaler_doc:
-            raise RuntimeError(f"⚠️ 找不到模型 {args['model_name']} 的缩放器记录")
-
-        # ------------------------- 反序列化处理 -------------------------
-        def load_scaler(data: bytes) -> object:
-            """安全加载序列化对象"""
-            with BytesIO(data) as buffer:
-                buffer.seek(0)  # 确保指针复位
-                try:
-                    return joblib.load(buffer)
-                except joblib.UnpicklingError as e:
-                    raise RuntimeError("反序列化失败 (可能版本不兼容)") from e
-
-        # 特征缩放器加载
-        feature_data = scaler_doc["feature_scaler"]
-        if not isinstance(feature_data, bytes):
-            raise RuntimeError("特征缩放器数据格式异常")
-        feature_scaler = load_scaler(feature_data)
-
-        if only_feature_scaler:
-            return feature_scaler
-
-        # 目标缩放器加载
-        target_data = scaler_doc["target_scaler"]
-        if not isinstance(target_data, bytes):
-            raise RuntimeError("目标缩放器数据格式异常")
-        target_scaler = load_scaler(target_data)
-
-        print(f"✅ 成功加载 {args['model_name']} 的缩放器 (版本时间: {scaler_doc.get('gen_time', '未知')})")
-        return feature_scaler, target_scaler
-
-    except PyMongoError as e:
-        raise RuntimeError(f"🔌 数据库操作失败: {str(e)}") from e
-    except RuntimeError as e:
-        raise  RuntimeError(f"🔌 mongo操作失败: {str(e)}") from e# 直接传递已封装的异常
-    except Exception as e:
-        raise RuntimeError(f"❌ 未知系统异常: {str(e)}") from e
-    finally:
-        # ------------------------- 资源清理 -------------------------
-        if client:
-            client.close()
-

BIN
common/jar/hive-jdbc-standalone.jar


+ 234 - 0
data_processing/data_operation/custom_data_handler.py

@@ -262,3 +262,237 @@ class CustomDataHandler(object):
         pre_data.loc[:, features] = scaled_pre_data
         pre_x = self.get_predict_data([pre_data], time_series)
         return pre_x, data
+
+class MultiNwpDataHandler(object):
+    def __init__(self, logger, args):
+        self.logger = logger
+        self.opt = argparse.Namespace(**args)
+
+    def get_train_data(self, dfs, col_time, target, time_series=1):
+        train_x, valid_x, train_y, valid_y = [], [], [], []
+        for i, df in enumerate(dfs, start=1):
+            if len(df) < self.opt.Model["time_step"]:
+                self.logger.info("特征处理-训练数据-不满足time_step")
+
+            datax, datay = self.get_timestep_features(df, col_time, target, is_train=True, time_series=time_series)
+            if len(datax) < 10:
+                self.logger.info("特征处理-训练数据-无法进行最小分割")
+                continue
+            tx, vx, ty, vy = self.train_valid_split(datax, datay, valid_rate=self.opt.Model["valid_data_rate"], shuffle=self.opt.Model['shuffle_train_data'])
+            train_x.extend(tx)
+            valid_x.extend(vx)
+            train_y.extend(ty)
+            valid_y.extend(vy)
+
+        train_x = [np.array([x[0].values for x in train_x]), np.array([x[1].values for x in train_x])]
+        valid_x = [np.array([x[0].values for x in valid_x]), np.array([x[1].values for x in valid_x])]
+        train_y = np.concatenate(np.array([[y.iloc[:, 1].values for y in train_y]]), axis=0)
+        valid_y = np.concatenate(np.array([[y.iloc[:, 1].values for y in valid_y]]), axis=0)
+
+        return train_x, valid_x, train_y, valid_y
+
+    def get_predict_data(self, dfs, time_series=1):
+        test_x = []
+        for i, df in enumerate(dfs, start=1):
+            if len(df) < self.opt.Model["time_step"]*time_series:
+                self.logger.info("特征处理-预测数据-不满足time_step")
+                continue
+            datax = self.get_predict_features(df, time_series)
+            test_x.append(datax)
+        test_x = np.concatenate(test_x, axis=0)
+        return test_x
+
+    def get_predict_features(self, norm_data, time_series=1):
+        """
+        均分数据,获取预测数据集
+        """
+        time_step = self.opt.Model["time_step"]
+        feature_data = norm_data.loc[:, self.opt.features].reset_index(drop=True)
+        time_step *= int(time_series)
+        time_step_loc = time_step - 1
+        iters = int(len(feature_data)) // time_step
+        end = int(len(feature_data)) % time_step
+        features_x = np.array([feature_data.loc[i*time_step:i*time_step + time_step_loc, self.opt.features1].reset_index(drop=True) for i in range(iters)])
+        features_x1 = [feature_data.loc[i*time_step:i*time_step + time_step_loc, self.opt.features2].reset_index(drop=True) for i in range(iters)]
+
+        if end > 0:
+            df = feature_data.tail(end)
+            df_repeated = pd.concat([df] + [pd.DataFrame([df.iloc[-1]]* (time_step-end))]).reset_index(drop=True)
+            features_x = np.concatenate((features_x, np.expand_dims(df_repeated, 0)), axis=0)
+        return features_x
+
+    def get_timestep_features(self, norm_data, col_time, target, is_train, time_series):
+        """
+        步长分割数据,分区建模
+        """
+        time_step = self.opt.Model["time_step"]
+        feature_data = norm_data.reset_index(drop=True)
+        time_step_loc = time_step*time_series - 1
+        train_num = int(len(feature_data))
+        label_features_power = [col_time, target] if is_train is True else [col_time, target]
+        nwp_cs = self.opt.features1
+        nwp_cs1 = self.opt.features2
+        nwp = [feature_data.loc[i:i + time_step_loc, nwp_cs].reset_index(drop=True) for i in range(train_num - time_step*time_series + 1)]
+        nwp1 = [feature_data.loc[i:i + time_step_loc, nwp_cs1].reset_index(drop=True) for i in range(train_num - time_step*time_series + 1)]
+
+        labels_power = [feature_data.loc[i:i + time_step_loc, label_features_power].reset_index(drop=True) for i in range(train_num - time_step*time_series + 1)]
+        features_x, features_y = [], []
+        for i, row in enumerate(zip(nwp, nwp1, labels_power)):
+            features_x.append([row[0], row[1]])
+            features_y.append(row[2])
+        return features_x, features_y
+
+    def fill_train_data(self, unite, col_time):
+        """
+        补值
+        """
+        unite[col_time] = pd.to_datetime(unite[col_time])
+        unite['time_diff'] = unite[col_time].diff()
+        dt_short = pd.Timedelta(minutes=15)
+        dt_long = pd.Timedelta(minutes=15 * self.opt.Model['how_long_fill'])
+        data_train = self.missing_time_splite(unite, dt_short, dt_long, col_time)
+        miss_points = unite[(unite['time_diff'] > dt_short) & (unite['time_diff'] < dt_long)]
+        miss_number = miss_points['time_diff'].dt.total_seconds().sum(axis=0) / (15 * 60) - len(miss_points)
+        self.logger.info("再次测算,需要插值的总点数为:{}".format(miss_number))
+        if miss_number > 0 and self.opt.Model["train_data_fill"]:
+            data_train = self.data_fill(data_train, col_time)
+        return data_train
+
+    def fill_pre_data(self, unite):
+        unite = unite.interpolate(method='linear')  # nwp先进行线性填充
+        unite = unite.ffill().bfill() # 再对超过采样边缘无法填充的点进行二次填充
+        return unite
+
+    def missing_time_splite(self, df, dt_short, dt_long, col_time):
+        df.reset_index(drop=True, inplace=True)
+        n_long, n_short, n_points = 0, 0, 0
+        start_index = 0
+        dfs = []
+        for i in range(1, len(df)):
+            if df['time_diff'][i] >= dt_long:
+                df_long = df.iloc[start_index:i, :-1]
+                dfs.append(df_long)
+                start_index = i
+                n_long += 1
+            if df['time_diff'][i] > dt_short:
+                self.logger.info(f"{df[col_time][i-1]} ~ {df[col_time][i]}")
+                points = df['time_diff'].dt.total_seconds()[i]/(60*15)-1
+                self.logger.info("缺失点数:{}".format(points))
+                if df['time_diff'][i] < dt_long:
+                    n_short += 1
+                    n_points += points
+                    self.logger.info("需要补值的点数:{}".format(points))
+        dfs.append(df.iloc[start_index:, :-1])
+        self.logger.info(f"数据总数:{len(df)}, 时序缺失的间隔:{n_short}, 其中,较长的时间间隔:{n_long}")
+        self.logger.info("需要补值的总点数:{}".format(n_points))
+        return dfs
+
+    def data_fill(self, dfs, col_time, test=False):
+        dfs_fill, inserts = [], 0
+        for i, df in enumerate(dfs):
+            df = rm_duplicated(df, self.logger)
+            df1 = df.set_index(col_time, inplace=False)
+            dff = df1.resample('15T').interpolate(method='linear')  # 采用线性补值,其他补值方法需要进一步对比
+            dff.reset_index(inplace=True)
+            points = len(dff) - len(df1)
+            dfs_fill.append(dff)
+            self.logger.info("{} ~ {} 有 {} 个点, 填补 {} 个点.".format(dff.iloc[0, 0], dff.iloc[-1, 0], len(dff), points))
+            inserts += points
+        name = "预测数据" if test is True else "训练集"
+        self.logger.info("{}分成了{}段,实际一共补值{}点".format(name, len(dfs_fill), inserts))
+        return dfs_fill
+
+    def train_valid_split(self, datax, datay, valid_rate, shuffle):
+        shuffle_index = np.random.permutation(len(datax))
+        indexs = shuffle_index.tolist() if shuffle else np.arange(0, len(datax)).tolist()
+        valid_size = int(len(datax) * valid_rate)
+        valid_index = indexs[-valid_size:]
+        train_index = indexs[:-valid_size]
+        tx, vx, ty, vy = [], [], [], []
+        for i, data in enumerate(zip(datax, datay)):
+            if i in train_index:
+                tx.append(data[0])
+                ty.append(data[1])
+            elif i in valid_index:
+                vx.append(data[0])
+                vy.append(data[1])
+        return tx, vx, ty, vy
+
+    def train_data_handler(self, data, time_series=1):
+        """
+        训练数据预处理:
+        清洗+补值+归一化
+        Args:
+            data: 从mongo中加载的数据
+            opt:参数命名空间
+        return:
+            x_train
+            x_valid
+            y_train
+            y_valid
+        """
+        col_time, features1, features2, target = self.opt.col_time, self.opt.features1, self.opt.features2, self.opt.target
+        # 清洗限电记录
+        if 'is_limit' in data.columns:
+            data = data[data['is_limit'] == False]
+        # 筛选特征,数值化,排序
+        train_data = data[[col_time] + features1 + features2 + [target]]
+        train_data = train_data.applymap(lambda x: float(x.to_decimal()) if isinstance(x, Decimal128) else float(x) if isinstance(x, numbers.Number) else x)
+        train_data = train_data.sort_values(by=col_time)
+        # 清洗特征平均缺失率大于20%的天
+        # train_data = missing_features(train_data, features, col_time)
+        # 对清洗完限电的数据进行特征预处理:
+        # 1.空值异常值清洗
+        train_data_cleaned = cleaning(train_data, '训练集', self.logger, features1 + features2 + [target], col_time)
+        self.opt.features = [x for x in train_data_cleaned.columns.tolist() if x not in [target, col_time] and x in features1+features2]
+        # 2. 标准化
+        # 创建特征和目标的标准化器
+        train_scaler = MinMaxScaler(feature_range=(0, 1))
+        target_scaler = MinMaxScaler(feature_range=(0, 1))
+        # 标准化特征和目标
+        scaled_train_data = train_scaler.fit_transform(train_data_cleaned[self.opt.features])
+        scaled_target = target_scaler.fit_transform(train_data_cleaned[[target]])
+        scaled_cap = target_scaler.transform(np.array([[float(self.opt.cap)]]))[0,0]
+        train_data_cleaned[self.opt.features] = scaled_train_data
+        train_data_cleaned[[target]] = scaled_target
+        # 3.缺值补值
+        train_datas = self.fill_train_data(train_data_cleaned, col_time)
+        # 保存两个scaler
+        scaled_train_bytes = BytesIO()
+        scaled_target_bytes = BytesIO()
+        joblib.dump(train_scaler, scaled_train_bytes)
+        joblib.dump(target_scaler, scaled_target_bytes)
+        scaled_train_bytes.seek(0)  # Reset pointer to the beginning of the byte stream
+        scaled_target_bytes.seek(0)
+
+        train_x, valid_x, train_y, valid_y = self.get_train_data(train_datas, col_time, target, time_series)
+        return train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap
+
+    def pre_data_handler(self, data, feature_scaler, time_series=1):
+        """
+        预测数据简单处理
+        Args:
+            data: 从mongo中加载的数据
+            opt:参数命名空间
+        return:
+            scaled_features: 反归一化的特征
+        """
+        # 清洗限电记录
+        if 'is_limit' in data.columns:
+            data = data[data['is_limit'] == False]
+        # features, time_steps, col_time, model_name, col_reserve = str_to_list(args['features']), int(
+        #     args['time_steps']), args['col_time'], args['model_name'], str_to_list(args['col_reserve'])
+        col_time, features = self.opt.col_time, self.opt.features
+        data = data.map(lambda x: float(x.to_decimal()) if isinstance(x, Decimal128) else float(x) if isinstance(x, numbers.Number) else x)
+        data = data.sort_values(by=col_time).reset_index(drop=True, inplace=False)
+        if not set(features).issubset(set(data.columns.tolist())):
+            raise ValueError("预测数据特征不满足模型特征!")
+        pre_data = data[features].copy()
+        pre_data[self.opt.zone] = 1
+        if self.opt.Model['predict_data_fill']:
+            pre_data = self.fill_pre_data(pre_data)
+        scaled_pre_data = feature_scaler.transform(pre_data)[:, :len(features)]
+        pre_data.drop(columns=self.opt.zone, inplace=True)
+        pre_data.loc[:, features] = scaled_pre_data
+        pre_x = self.get_predict_data([pre_data], time_series)
+        return pre_x, data

+ 80 - 11
data_processing/data_operation/data_nwp_ftp.py

@@ -16,6 +16,8 @@ import zipfile, tempfile, shutil, fnmatch
 from common.database_dml import insert_data_into_mongo
 from apscheduler.schedulers.background import BackgroundScheduler
 from apscheduler.triggers.cron import CronTrigger
+from ftplib import error_temp, error_perm
+from socket import error as socket_error
 
 from common.logs import Log
 logger = Log('data-processing').logger
@@ -117,6 +119,72 @@ def get_previous_moment(original_date, original_moment):
 
     return new_date, new_moment
 
+
+def safe_ftp_download(ftp, remote_file_path, local_file_path, max_retries=3):
+    file_name = os.path.basename(local_file_path)
+    attempt = 0
+
+    while attempt < max_retries:
+        try:
+            # 初始化下载参数
+            ftp.pwd()
+            ftp.sendcmd("NOOP")  # 保持连接活跃
+            ftp.voidcmd("TYPE I")  # 确保二进制模式
+
+            # 记录开始时间
+            start = time.time()
+            logger.info(f"开始第 {attempt + 1} 次下载尝试: {remote_file_path}")
+
+            # 使用上下文管理器确保文件关闭
+            with open(local_file_path, 'wb') as local_file:
+                # 设置超时和被动模式
+                ftp.timeout = 3000
+                ftp.set_pasv(True)
+
+                # 带进度回调的下载
+                def _callback(data):
+                    local_file.write(data)
+                    logger.debug(f"已接收 {len(data)} 字节")
+
+                ftp.retrbinary(f'RETR {remote_file_path}', _callback)
+
+            # 验证文件完整性
+            remote_size = ftp.size(remote_file_path)
+            local_size = os.path.getsize(local_file_path)
+            if local_size != remote_size:
+                raise IOError(f"文件大小不匹配: 本地 {local_size} vs 远程 {remote_size}")
+
+            # 记录成功日志
+            end = time.time()
+            now = datetime.datetime.now(pytz.utc).astimezone(timezone("Asia/Shanghai"))
+            logger_text = f"""下载成功!时间:{now.strftime('%Y-%m-%d %H:%M:%S')}
+                文件:{file_name}
+                耗时:{(end - start) / 60:.2f}分钟
+                平均速度:{(remote_size / 1024 / 1024) / (end - start):.2f}MB/s"""
+            logger.info(logger_text)
+            send_message(file_name, logger_text)
+            return True
+
+        except (error_temp, error_perm, socket_error, IOError) as e:
+            logger.error(f"第 {attempt + 1} 次下载失败: {str(e)}")
+            # 删除不完整文件
+            if os.path.exists(local_file_path):
+                try:
+                    os.remove(local_file_path)
+                    logger.warning(f"已删除不完整文件: {local_file_path}")
+                except Exception as clean_error:
+                    logger.error(f"文件清理失败: {str(clean_error)}")
+            attempt += 1
+            time.sleep(5)  # 重试间隔
+
+        except Exception as unexpected_error:
+            logger.critical(f"未知错误: {str(unexpected_error)}")
+            raise
+    logger_text = f"下载失败: 已达最大重试次数 {max_retries}"
+    logger.error(logger_text)
+    send_message(file_name, logger_text)
+    return False
+
 def download_zip_files_from_ftp(moment=None):
     now = datetime.datetime.now(pytz.utc).astimezone(timezone("Asia/Shanghai"))
     date = now.strftime("%Y%m%d")
@@ -140,22 +208,23 @@ def download_zip_files_from_ftp(moment=None):
         # 遍历文件列表,找到ZIP文件并下载
         for file_name in files:
             if fnmatch.fnmatch(file_name, zip_extension):
-                start = time.time()
+                # start = time.time()
                 remote_file_path = os.path.join(remote_dir, file_name)
                 local_file_path = os.path.join(local_dir, file_name)
 
                 if os.path.isfile(local_file_path):
                     continue
-
-                with open(local_file_path, 'wb') as local_file:
-                    logger.info(f"Downloading {remote_file_path} to {local_file_path}")
-                    ftp.retrbinary(f'RETR {remote_file_path}', local_file.write)
-                end = time.time()
-                now = datetime.datetime.now(pytz.utc).astimezone(timezone("Asia/Shanghai"))
-                logger_text = f"下载完成时间:{now.strftime('%Y-%m-%d %H:%M:%S')},下载 {file_name} 文件,用时 {(end - start)/60:.2f}分钟"
-                logger.info(logger_text)
-                send_message(file_name, logger_text)
-                zip_file_path.append(local_file_path)
+                if safe_ftp_download(ftp, remote_file_path, local_file_path):
+                # with open(local_file_path, 'wb') as local_file:
+                #     logger.info(f"Downloading {remote_file_path} to {local_file_path}")
+                #     ftp.retrbinary(f'RETR {remote_file_path}', local_file.write)
+
+                # end = time.time()
+                # now = datetime.datetime.now(pytz.utc).astimezone(timezone("Asia/Shanghai"))
+                # logger_text = f"下载完成时间:{now.strftime('%Y-%m-%d %H:%M:%S')},下载 {file_name} 文件,用时 {(end - start)/60:.2f}分钟"
+                # logger.info(logger_text)
+                # send_message(file_name, logger_text)
+                    zip_file_path.append(local_file_path)
     # 解压 ZIP 文件到临时目录
     for zip_file_p in zip_file_path:
         with zipfile.ZipFile(zip_file_p, 'r') as zip_ref:

+ 54 - 0
data_processing/data_operation/hive_to_mongo.py

@@ -0,0 +1,54 @@
+from flask import Flask,request,jsonify
+import time
+import logging
+import traceback
+from common.database_dml import insert_data_into_mongo,get_xmo_data_from_hive
+app = Flask('hive_to_mongo——service')
+
+
+@app.route('/hello', methods=['POST'])
+def hello():
+    return jsonify(message='Hello, World!')
+
+
+@app.route('/hive_to_mongo', methods=['POST'])
+def data_join():
+    # 获取程序开始时间  
+    start_time = time.time()  
+    result = {}
+    success = 0
+    print("Program starts execution!")
+    try:
+        args = request.values.to_dict()
+        print('args', args)
+        logger.info(args)
+        df_hive = get_xmo_data_from_hive(args)
+        if isinstance(df_hive, str):
+            success = 0
+            result['msg'] = df_hive
+        else:
+            insert_data_into_mongo(df_hive, args)
+            success = 1
+    except Exception as e:
+        my_exception = traceback.format_exc()
+        my_exception.replace("\n", "\t")
+        result['msg'] = my_exception
+    end_time = time.time()
+    result['success'] = success
+    result['args'] = args
+    result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
+    result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
+    print("Program execution ends!")
+    return result
+
+
+if __name__ == "__main__":
+    print("Program starts execution!")
+    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    logger = logging.getLogger("hive_to_mongo")
+    from waitress import serve
+    serve(app, host="0.0.0.0", port=10127)
+    print("server start!")
+    
+   
+    

+ 6 - 54
evaluation_processing/evaluation_accuracy.py

@@ -8,6 +8,9 @@ from flask import Flask, request
 import time
 import logging
 import traceback
+
+from common.database_dml import get_data_from_mongo, insert_data_into_mongo
+
 app = Flask('evaluation_accuracy——service')
 url = 'http://49.4.78.194:17160/apiCalculate/calculate'
 '''
@@ -97,66 +100,15 @@ def datetime_to_timestamp(dt):
     return int(round(time.mktime(dt.timetuple()))*1000)
 
 
-
-def get_data_from_mongo(args):
-    mongodb_connection,mongodb_database,mongodb_read_table = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['mongodb_read_table']
-    client = MongoClient(mongodb_connection)
-    # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
-    db = client[mongodb_database]
-    collection = db[mongodb_read_table]  # 集合名称
-    data_from_db = collection.find()  # 这会返回一个游标(cursor)
-    # 将游标转换为列表,并创建 pandas DataFrame
-    df = pd.DataFrame(list(data_from_db))
-    client.close()
-    return df
-    
-
-def insert_data_into_mongo(res_df,args):
-    mongodb_connection,mongodb_database,mongodb_write_table = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['mongodb_write_table']
-    client = MongoClient(mongodb_connection)
-    db = client[mongodb_database]
-    if mongodb_write_table in db.list_collection_names():
-        db[mongodb_write_table].drop()
-        print(f"Collection '{mongodb_write_table} already exist, deleted successfully!")
-    collection = db[mongodb_write_table]  # 集合名称
-    # 将 DataFrame 转为字典格式
-    data_dict = res_df.to_dict("records")  # 每一行作为一个字典
-    # 插入到 MongoDB
-    collection.insert_many(data_dict)
-    print("data inserted successfully!")
-    
-
-# def compute_accuracy(df,args):
-#     col_time,col_rp,col_pp,formulaType = args['col_time'],args['col_rp'],args['col_pp'],args['formulaType'].split('_')[0]
-#     dates = []
-#     accuracy = []
-#     df = df[(~np.isnan(df[col_rp]))&(~np.isnan(df[col_pp]))]
-#     df = df[[col_time,col_rp,col_pp]].rename(columns={col_time:'C_TIME',col_rp:'realValue',col_pp:'forecastAbleValue'})
-#     df['ableValue'] = df['realValue']
-#     df['C_TIME'] = df['C_TIME'].apply(lambda x: datetime.strptime(x, "%Y-%m-%d %H:%M:%S"))
-#     if formulaType=='DAY':
-#         df['C_DATE'] = df['C_TIME'].apply(lambda x: x.strftime("%Y-%m-%d"))
-#         days_list = df['C_DATE'].unique().tolist()
-#         for day in days_list:
-#             df_tmp = df[df['C_DATE'] == day]
-#             dates.append(day)
-#             accuracy.append(calculate_acc(df_tmp, args))
-#     else:
-#         points = df['C_TIME'].unique().tolist()
-#         for point in points:
-#             df_tmp = df[df['C_TIME'] == point]
-#             dates.append(point)
-#             accuracy.append(calculate_acc(df_tmp, args))
-#     print("accuray compute successfully!")
-#     return pd.DataFrame({'date':dates,'accuracy':accuracy})
-
 # 定义 RMSE 和 MAE 计算函数
 def rmse(y_true, y_pred):
     return np.sqrt(np.mean((y_true - y_pred) ** 2))
 
+
 def mae(y_true, y_pred):
     return np.mean(np.abs(y_true - y_pred))
-    
+
+
 def compute_accuracy(df,args):
     col_time,col_rp,col_pp = args['col_time'],args['col_rp'],args['col_pp']
     df[col_time] = df[col_time].apply(lambda x:pd.to_datetime(x).strftime("%Y-%m-%d")) 

+ 6 - 21
models_processing/model_predict/model_prediction_lightgbm.py

@@ -5,7 +5,7 @@ from flask import Flask,request
 import time
 import logging
 import traceback
-from common.database_dml import get_data_from_mongo,insert_data_into_mongo
+from common.database_dml import get_data_from_mongo, insert_data_into_mongo, get_pickle_model_from_mongo
 from common.alert import send_message
 from datetime import datetime, timedelta
 import pytz
@@ -44,16 +44,8 @@ def forecast_data_distribution(pre_data,args):
         result = get_xxl_dq(farm_id, dt)
     else:
         df = pre_data.sort_values(by=col_time).fillna(method='ffill').fillna(method='bfill')
-        mongodb_connection, mongodb_database, mongodb_model_table, model_name = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/", \
-        args['mongodb_database'], args['mongodb_model_table'], args['model_name']
-        client = MongoClient(mongodb_connection)
-        db = client[mongodb_database]
-        collection = db[mongodb_model_table]
-        model_data = collection.find_one({"model_name": model_name})
-        if model_data is not None:
-            model_binary = model_data['model']  # 确保这个字段是存储模型的二进制数据
-            # 反序列化模型
-            model = pickle.loads(model_binary)
+        model = get_pickle_model_from_mongo(args)
+        if model is not None:
             diff = set(model.feature_name()) - set(pre_data.columns)
             if len(diff) > 0:
                 send_message('lightgbm预测组件', farm_id, f'NWP特征列缺失,使用DQ代替!features:{diff}')
@@ -73,18 +65,11 @@ def forecast_data_distribution(pre_data,args):
 
 
 def model_prediction(df,args):
-    mongodb_connection,mongodb_database,mongodb_model_table,model_name,howLongAgo,farm_id,target = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['mongodb_model_table'],args['model_name'],int(args['howLongAgo']),args['farm_id'],args['target']
-    client = MongoClient(mongodb_connection)
-    db = client[mongodb_database]
-    collection = db[mongodb_model_table]
-    model_data = collection.find_one({"model_name": model_name})
+    model_name, howLongAgo, farm_id, target = args['model_name'], int(args['howLongAgo']), args['farm_id'], args['target']
     if 'is_limit' in df.columns:
         df = df[df['is_limit'] == False]
-
-    if model_data is not None:
-        model_binary = model_data['model']  # 确保这个字段是存储模型的二进制数据
-        # 反序列化模型 
-        model = pickle.loads(model_binary)
+    model = get_pickle_model_from_mongo(args)
+    if model is not None:
         df['power_forecast'] = model.predict(df[model.feature_name()])
         df.loc[df['power_forecast'] < 0, 'power_forecast'] = 0
         df['model'] = model_name

+ 96 - 0
models_processing/model_predict/model_prediction_photovoltaic_physical.py

@@ -0,0 +1,96 @@
+import pandas as pd
+from pymongo import MongoClient
+import pickle
+from flask import Flask, request
+import time
+import logging
+import traceback
+from common.database_dml import get_data_from_mongo, insert_data_into_mongo
+
+
+app = Flask('model_prediction_photovoltaic_physical——service')
+
+
+def str_to_list(arg):
+    if arg == '':
+        return []
+    else:
+        return arg.split(',')
+
+
+def forecast_data_distribution(pre_data, args):
+    col_time = args['col_time']
+    farm_id = args['farm_id']
+    col_radiance = args['col_radiance']
+    radiance_max = float(args['radiance_max'])
+    cap = float(args['cap'])
+    pre_data['farm_id'] = farm_id
+    pre_data['power_forecast'] = round(pre_data[col_radiance] * cap / radiance_max, 2)
+    if 'sunrise_time' in args:
+        sunrise_time = args['sunrise_time']
+        pre_data.loc[pre_data[col_time].dt.time < sunrise_time, 'power_forecast'] = 0
+    if 'sunset_time' in args:
+        sunset_time = args['sunset_time']
+        pre_data[pre_data[col_time] > sunset_time, 'power_forecast'] = 0
+    return pre_data[['farm_id', 'date_time', 'power_forecast']]
+
+
+def model_prediction(df, args):
+    # 新增日出、日落时间参数
+    howLongAgo, farm_id, target, cap, col_radiance, radiance_max, model_name, col_time = int(args['howLongAgo']), args['farm_id'], \
+    args['target'], args['cap'], args['col_radiance'], args['radiance_max'], args['model_name'], args['col_time']
+    df['power_forecast'] = round(df[col_radiance]*cap/radiance_max, 2)
+    df.loc[df['power_forecast'] < 0, 'power_forecast'] = 0
+    if 'sunrise_time' in args:
+        sunrise_time = args['sunrise_time']
+        df.loc[df[col_time].dt.time < sunrise_time, 'power_forecast'] = 0
+    if 'sunset_time' in args:
+        sunset_time = args['sunset_time']
+        df[df[col_time] > sunset_time, 'power_forecast'] = 0
+    df['model'] = model_name
+    df['howLongAgo'] = howLongAgo
+    df['farm_id'] = farm_id
+    print("model predict result  successfully!")
+    return df[['dateTime', 'howLongAgo', 'model', 'farm_id', 'power_forecast', target]]
+
+
+@app.route('/model_prediction_photovoltaic_physical', methods=['POST'])
+def model_prediction_photovoltaic_physical():
+    # 获取程序开始时间
+    start_time = time.time()
+    result = {}
+    success = 0
+    print("Program starts execution!")
+    try:
+        args = request.values.to_dict()
+        print('args', args)
+        logger.info(args)
+        forecast_file = int(args['forecast_file'])
+        power_df = get_data_from_mongo(args)
+        if forecast_file == 1:
+            predict_data = forecast_data_distribution(power_df, args)
+        else:
+            predict_data = model_prediction(power_df, args)
+        insert_data_into_mongo(predict_data, args)
+        success = 1
+    except Exception as e:
+        my_exception = traceback.format_exc()
+        my_exception.replace("\n", "\t")
+        result['msg'] = my_exception
+    end_time = time.time()
+    result['success'] = success
+    result['args'] = args
+    result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
+    result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
+    print("Program execution ends!")
+    return result
+
+
+if __name__ == "__main__":
+    print("Program starts execution!")
+    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    logger = logging.getLogger("model_prediction_photovoltaic_physical log")
+    from waitress import serve
+
+    serve(app, host="0.0.0.0", port=10126)
+    print("server start!")

+ 37 - 6
models_processing/model_tf/losses.py

@@ -115,8 +115,8 @@ class SouthLoss(Loss):
                  name: str = "south_loss",
                  reduction: str = "sum_over_batch_size"):
         # 参数校验
-        # if not 0 <= cap <= 1:
-        #     raise ValueError("cap 必须为归一化后的值且位于 [0,1] 区间")
+        if not 0 <= cap <= 1:
+            raise ValueError("cap 必须为归一化后的值且位于 [0,1] 区间")
 
         super().__init__(name=name, reduction=reduction)
 
@@ -195,10 +195,41 @@ class SouthLoss(Loss):
 
 
 
+class NorthChina(Loss):
+    """Root Mean Squared Error 损失函数(兼容单变量/多变量回归)"""
+
+    def __init__(self,
+                 name="north_china_loss",
+                 reduction="sum_over_batch_size",  # 默认自动选择 'sum_over_batch_size' (等效于 mean)
+                 **kwargs):
+        super().__init__(name=name, reduction=reduction)
+
+    def call(self, y_true, y_pred):
+        # 计算误差 e = y_true - y_pred
+        error = y_true - y_pred
+        abs_error = tf.abs(error)
+
+        # 加上 epsilon 避免除以 0
+        epsilon = 1e-8
+        weight = abs_error / (tf.reduce_sum(abs_error) + epsilon)
+
+        weighted_squared_error = tf.square(error) * weight
+        loss = tf.sqrt(tf.reduce_sum(weighted_squared_error))
+        return loss
+
+    def get_config(self):
+        """支持序列化配置(用于模型保存/加载)"""
+        base_config = super().get_config()
+        return base_config
+
+
+
+
 region_loss_d = {
     'northeast': lambda region: RMSE(region),
     'south': lambda cap, region: SouthLoss(cap, region),
-    'zone': lambda region: MSE_ZONE(region) # 分区建模损失:MSE + 分区总和一致性约束
+    'zone': lambda region: MSE_ZONE(region), # 分区建模损失:MSE + 分区总和一致性约束
+    'northchina': lambda region: NorthChina(region) #华北损失函数
 }
 
 
@@ -224,18 +255,18 @@ if __name__ == '__main__':
     mse = tf.keras.losses.MeanSquaredError()(y_true, y_pred).numpy()
 
     # 自定义损失(权重=1时等效于MSE)
-    custom_mse = MSE(name='test')(y_true, y_pred).numpy()
+    custom_mse = NorthChina(name='test')(y_true, y_pred).numpy()
 
     print("标准 MSE:", mse)  # 输出: 0.25
     print("自定义 MSE:", custom_mse)  # 应输出: 0.25
-    assert abs(mse - custom_mse) < 1e-6
+    # assert abs(mse - custom_mse) < 1e-6
 
     # 定义变量和优化器
     y_pred_var = tf.Variable([[1.5], [2.5], [3.5]], dtype=tf.float32)
     optimizer = tf.keras.optimizers.Adam()
 
     with tf.GradientTape() as tape:
-        loss = MSE(name='test')(y_true, y_pred_var)
+        loss = NorthChina(name='test')(y_true, y_pred_var)
     grads = tape.gradient(loss, y_pred_var)
 
     # 理论梯度公式:2*(y_pred - y_true)/N (N=3)

+ 1 - 1
models_processing/model_tf/lstm.yaml

@@ -2,7 +2,7 @@ Model:
   add_train: true
   batch_size: 64
   dropout_rate: 0.2
-  epoch: 200
+  epoch: 500
   fusion: true
   hidden_size: 64
   his_points: 16

+ 1 - 1
models_processing/model_tf/tf_bilstm.py

@@ -11,7 +11,7 @@ from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoa
 from tensorflow.keras import optimizers, regularizers
 from models_processing.model_tf.losses import region_loss
 import numpy as np
-from common.database_dml_koi import *
+from common.database_dml import *
 from models_processing.model_tf.settings import set_deterministic
 from threading import Lock
 import argparse

+ 1 - 1
models_processing/model_tf/tf_bilstm_2.py

@@ -11,7 +11,7 @@ from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoa
 from tensorflow.keras import optimizers, regularizers
 from models_processing.model_tf.losses import region_loss
 import numpy as np
-from common.database_dml_koi import *
+from common.database_dml import *
 from models_processing.model_tf.settings import set_deterministic
 from threading import Lock
 import argparse

+ 1 - 1
models_processing/model_tf/tf_bp.py

@@ -13,7 +13,7 @@ from tensorflow.keras import optimizers, regularizers
 from models_processing.model_tf.losses import region_loss
 from models_processing.model_tf.settings import set_deterministic
 import numpy as np
-from common.database_dml_koi import *
+from common.database_dml import *
 from threading import Lock
 import argparse
 model_lock = Lock()

+ 4 - 3
models_processing/model_tf/tf_bp_pre.py

@@ -8,7 +8,7 @@ import json, copy
 import numpy as np
 from flask import Flask, request, g
 import logging, argparse, traceback
-from common.database_dml_koi import *
+from common.database_dml import *
 from common.processing_data_common import missing_features, str_to_list
 from data_processing.data_operation.data_handler import DataHandler
 from threading import Lock
@@ -17,6 +17,7 @@ from copy import deepcopy
 model_lock = Lock()
 from itertools import chain
 from common.logs import Log
+from common.data_utils import deep_update
 from tf_bp import BPHandler
 logger = Log('tf_bp').logger
 np.random.seed(42)  # NumPy随机种子
@@ -35,7 +36,7 @@ def update_config():
     request_args = request.values.to_dict()
     # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
-    current_config.update(request_args)
+    current_config = deep_update(current_config, request_args)
 
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)
@@ -82,7 +83,7 @@ def model_prediction_bp():
             res_cols = ['date_time', 'power_forecast', 'farm_id']
 
         pre_data = pre_data[res_cols]
-        pre_data.loc[:, 'power_forecast'] = pre_data['power_forecast'].round(2)
+        pre_data.loc[:, 'power_forecast'] = pre_data.loc[:, 'power_forecast'].apply(lambda x: float(f"{x:.2f}"))
         pre_data.loc[pre_data['power_forecast'] > float(args['cap']), 'power_forecast'] = float(args['cap'])
         pre_data.loc[pre_data['power_forecast'] < 0, 'power_forecast'] = 0
 

+ 3 - 2
models_processing/model_tf/tf_bp_train.py

@@ -14,9 +14,10 @@ from data_processing.data_operation.data_handler import DataHandler
 import time, yaml
 from copy import deepcopy
 from models_processing.model_tf.tf_bp import BPHandler
-from common.database_dml_koi import *
+from common.database_dml import *
 import matplotlib.pyplot as plt
 from common.logs import Log
+from common.data_utils import deep_update
 logger = Log('tf_bp').logger
 np.random.seed(42)  # NumPy随机种子
 app = Flask('tf_bp_train——service')
@@ -34,7 +35,7 @@ def update_config():
     request_args = request.values.to_dict()
     # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
-    current_config.update(request_args)
+    current_config = deep_update(current_config, request_args)
 
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)

+ 1 - 1
models_processing/model_tf/tf_cnn.py

@@ -13,7 +13,7 @@ from tensorflow.keras import optimizers, regularizers
 from models_processing.model_tf.losses import region_loss
 from models_processing.model_tf.settings import set_deterministic
 import numpy as np
-from common.database_dml_koi import *
+from common.database_dml import *
 from threading import Lock
 import argparse
 model_lock = Lock()

+ 4 - 3
models_processing/model_tf/tf_cnn_pre.py

@@ -8,7 +8,7 @@ import json, copy
 import numpy as np
 from flask import Flask, request, g
 import logging, argparse, traceback
-from common.database_dml_koi import *
+from common.database_dml import *
 from common.processing_data_common import missing_features, str_to_list
 from data_processing.data_operation.data_handler import DataHandler
 from threading import Lock
@@ -17,6 +17,7 @@ from copy import deepcopy
 model_lock = Lock()
 from itertools import chain
 from common.logs import Log
+from common.data_utils import deep_update
 from tf_cnn import CNNHandler
 # logger = Log('tf_bp').logger()
 logger = Log('tf_cnn').logger
@@ -36,7 +37,7 @@ def update_config():
     request_args = request.values.to_dict()
     # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
-    current_config.update(request_args)
+    current_config = deep_update(current_config, request_args)
 
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)
@@ -83,7 +84,7 @@ def model_prediction_cnn():
             res_cols = ['date_time', 'power_forecast', 'farm_id']
         pre_data = pre_data[res_cols]
 
-        pre_data.loc[:, 'power_forecast'] = pre_data['power_forecast'].round(2)
+        pre_data.loc[:, 'power_forecast'] = pre_data.loc[:, 'power_forecast'].apply(lambda x: float(f"{x:.2f}"))
         pre_data.loc[pre_data['power_forecast'] > float(args['cap']), 'power_forecast'] = float(args['cap'])
         pre_data.loc[pre_data['power_forecast'] < 0, 'power_forecast'] = 0
 

+ 3 - 2
models_processing/model_tf/tf_cnn_train.py

@@ -13,9 +13,10 @@ from data_processing.data_operation.data_handler import DataHandler
 import time, yaml
 from copy import deepcopy
 from models_processing.model_tf.tf_cnn import CNNHandler
-from common.database_dml_koi import *
+from common.database_dml import *
 import matplotlib.pyplot as plt
 from common.logs import Log
+from common.data_utils import deep_update
 # logger = logging.getLogger()
 logger = Log('tf_cnn').logger
 np.random.seed(42)  # NumPy随机种子
@@ -33,7 +34,7 @@ def update_config():
     request_args = request.values.to_dict()
     # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
-    current_config.update(request_args)
+    current_config = deep_update(current_config, request_args)
 
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)

+ 1 - 1
models_processing/model_tf/tf_lstm.py

@@ -11,7 +11,7 @@ from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoa
 from tensorflow.keras import optimizers, regularizers
 from models_processing.model_tf.losses import region_loss
 import numpy as np
-from common.database_dml_koi import *
+from common.database_dml import *
 from models_processing.model_tf.settings import set_deterministic
 from threading import Lock
 import argparse

+ 4 - 4
models_processing/model_tf/tf_lstm2_pre.py

@@ -8,7 +8,7 @@ import json, copy
 import numpy as np
 from flask import Flask, request, g
 import logging, argparse, traceback
-from common.database_dml_koi import *
+from common.database_dml import *
 from common.processing_data_common import missing_features, str_to_list
 from data_processing.data_operation.data_handler import DataHandler
 from threading import Lock
@@ -18,6 +18,7 @@ model_lock = Lock()
 from itertools import chain
 from common.logs import Log
 from tf_lstm import TSHandler
+from common.data_utils import deep_update
 # logger = Log('tf_bp').logger()
 logger = Log('tf_ts2').logger
 np.random.seed(42)  # NumPy随机种子
@@ -37,8 +38,7 @@ def update_config():
     # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
     request_args['time_series'] = request_args.get('time_series', 2)
-    current_config.update(request_args)
-
+    current_config = deep_update(current_config, request_args)
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)
     g.dh = DataHandler(logger, current_config)  # 每个请求独立实例
@@ -83,7 +83,7 @@ def model_prediction_lstm2():
             res_cols = ['date_time', 'power_forecast', 'farm_id']
         pre_data = pre_data[res_cols]
 
-        pre_data.loc[:, 'power_forecast'] = pre_data['power_forecast'].round(2)
+        pre_data.loc[:, 'power_forecast'] = pre_data.loc[:, 'power_forecast'].apply(lambda x: float(f"{x:.2f}"))
         pre_data.loc[pre_data['power_forecast'] > float(args['cap']), 'power_forecast'] = float(args['cap'])
         pre_data.loc[pre_data['power_forecast'] < 0, 'power_forecast'] = 0
 

+ 4 - 3
models_processing/model_tf/tf_lstm2_train.py

@@ -13,8 +13,10 @@ from data_processing.data_operation.data_handler import DataHandler
 import time, yaml, threading
 from copy import deepcopy
 from models_processing.model_tf.tf_lstm import TSHandler
-from common.database_dml_koi import *
+from common.database_dml import *
 from common.logs import Log
+from common.data_utils import deep_update
+
 logger = Log('tf_ts2').logger
 np.random.seed(42)  # NumPy随机种子
 app = Flask('tf_lstm2_train——service')
@@ -32,8 +34,7 @@ def update_config():
     # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
     request_args['time_series'] = request_args.get('time_series', 2)
-    current_config.update(request_args)
-
+    current_config = deep_update(current_config, request_args)
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)
     g.dh = DataHandler(logger, current_config)  # 每个请求独立实例

+ 4 - 4
models_processing/model_tf/tf_lstm3_pre.py

@@ -8,7 +8,7 @@ import json, copy
 import numpy as np
 from flask import Flask, request, g
 import logging, argparse, traceback
-from common.database_dml_koi import *
+from common.database_dml import *
 from common.processing_data_common import missing_features, str_to_list
 from data_processing.data_operation.data_handler import DataHandler
 from threading import Lock
@@ -17,6 +17,7 @@ from copy import deepcopy
 model_lock = Lock()
 from itertools import chain
 from common.logs import Log
+from common.data_utils import deep_update
 logger = Log('tf_ts3').logger
 np.random.seed(42)  # NumPy随机种子
 # tf.set_random_seed(42)  # TensorFlow随机种子
@@ -36,8 +37,7 @@ def update_config():
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
     request_args['time_series'] = request_args.get('time_series', 3)
     request_args['lstm_type'] = request_args.get('lstm_type', 1)
-    current_config.update(request_args)
-
+    current_config = deep_update(current_config, request_args)
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)
     g.dh = DataHandler(logger, current_config)  # 每个请求独立实例
@@ -87,7 +87,7 @@ def model_prediction_lstm3():
             res_cols = ['date_time', 'power_forecast', 'farm_id']
         pre_data = pre_data[res_cols]
 
-        pre_data.loc[:, 'power_forecast'] = pre_data['power_forecast'].round(2)
+        pre_data.loc[:, 'power_forecast'] = pre_data.loc[:, 'power_forecast'].apply(lambda x: float(f"{x:.2f}"))
         pre_data.loc[pre_data['power_forecast'] > float(args['cap']), 'power_forecast'] = float(args['cap'])
         pre_data.loc[pre_data['power_forecast'] < 0, 'power_forecast'] = 0
 

+ 3 - 2
models_processing/model_tf/tf_lstm3_train.py

@@ -12,8 +12,9 @@ import logging, argparse
 from data_processing.data_operation.data_handler import DataHandler
 import time, yaml, threading
 from copy import deepcopy
-from common.database_dml_koi import *
+from common.database_dml import *
 from common.logs import Log
+from common.data_utils import deep_update
 logger = Log('tf_ts3').logger
 np.random.seed(42)  # NumPy随机种子
 app = Flask('tf_lstm3_train——service')
@@ -32,7 +33,7 @@ def update_config():
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
     request_args['time_series'] = request_args.get('time_series', 3)
     request_args['lstm_type'] = request_args.get('lstm_type', 1)
-    current_config.update(request_args)
+    current_config = deep_update(current_config, request_args)
 
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)

+ 4 - 3
models_processing/model_tf/tf_lstm_pre.py

@@ -8,7 +8,7 @@ import json, copy
 import numpy as np
 from flask import Flask, request, g
 import logging, argparse, traceback
-from common.database_dml_koi import *
+from common.database_dml import *
 from common.processing_data_common import missing_features, str_to_list
 from data_processing.data_operation.data_handler import DataHandler
 from threading import Lock
@@ -17,6 +17,7 @@ from copy import deepcopy
 model_lock = Lock()
 from itertools import chain
 from common.logs import Log
+from common.data_utils import deep_update
 from tf_lstm import TSHandler
 # logger = Log('tf_bp').logger()
 logger = Log('tf_ts').logger
@@ -37,7 +38,7 @@ def update_config():
     # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
     request_args['time_series'] = request_args.get('time_series', 1)
-    current_config.update(request_args)
+    current_config = deep_update(current_config, request_args)
 
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)
@@ -81,7 +82,7 @@ def model_prediction_lstm():
             res_cols = ['date_time', 'power_forecast', 'farm_id']
         pre_data = pre_data[res_cols]
 
-        pre_data.loc[:, 'power_forecast'] = pre_data['power_forecast'].round(2)
+        pre_data.loc[:, 'power_forecast'] = pre_data.loc[:, 'power_forecast'].apply(lambda x: float(f"{x:.2f}"))
         pre_data.loc[pre_data['power_forecast'] > float(args['cap']), 'power_forecast'] = float(args['cap'])
         pre_data.loc[pre_data['power_forecast'] < 0, 'power_forecast'] = 0
 

+ 3 - 2
models_processing/model_tf/tf_lstm_train.py

@@ -13,8 +13,9 @@ from data_processing.data_operation.data_handler import DataHandler
 import time, yaml, threading
 from copy import deepcopy
 from models_processing.model_tf.tf_lstm import TSHandler
-from common.database_dml_koi import *
+from common.database_dml import *
 from common.logs import Log
+from common.data_utils import deep_update
 logger = Log('tf_ts').logger
 np.random.seed(42)  # NumPy随机种子
 app = Flask('tf_lstm_train——service')
@@ -32,7 +33,7 @@ def update_config():
     # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
     request_args['time_series'] = request_args.get('time_series', 1)
-    current_config.update(request_args)
+    current_config = deep_update(current_config, request_args)
 
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)

+ 1 - 1
models_processing/model_tf/tf_lstm_zone.py

@@ -11,7 +11,7 @@ from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoa
 from tensorflow.keras import optimizers, regularizers
 from models_processing.model_tf.losses import region_loss
 import numpy as np
-from common.database_dml_koi import *
+from common.database_dml import *
 from models_processing.model_tf.settings import set_deterministic
 from threading import Lock
 import argparse

+ 4 - 4
models_processing/model_tf/tf_lstm_zone_pre.py

@@ -8,7 +8,7 @@ import json, copy
 import numpy as np
 from flask import Flask, request, g
 import logging, argparse, traceback
-from common.database_dml_koi import *
+from common.database_dml import *
 from common.processing_data_common import missing_features, str_to_list
 from data_processing.data_operation.custom_data_handler import CustomDataHandler
 from models_processing.model_tf.tf_lstm_zone import TSHandler
@@ -18,6 +18,7 @@ from copy import deepcopy
 model_lock = Lock()
 from itertools import chain
 from common.logs import Log
+from common.data_utils import deep_update
 # logger = Log('tf_bp').logger()
 logger = Log('tf_ts').logger
 np.random.seed(42)  # NumPy随机种子
@@ -38,8 +39,7 @@ def update_config():
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
     request_args['time_series'] = request_args.get('time_series', 1)
     request_args['zone'] = request_args['zone'].split(',')
-    current_config.update(request_args)
-
+    current_config = deep_update(current_config, request_args)
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)
     g.dh = CustomDataHandler(logger, current_config)  # 每个请求独立实例
@@ -82,7 +82,7 @@ def model_prediction_lstm():
             res_cols = ['date_time', 'power_forecast', 'farm_id']
         pre_data = pre_data[res_cols]
 
-        pre_data.loc[:, 'power_forecast'] = pre_data['power_forecast'].round(2)
+        pre_data.loc[:, 'power_forecast'] = pre_data.loc[:, 'power_forecast'].apply(lambda x: float(f"{x:.2f}"))
         pre_data.loc[pre_data['power_forecast'] > float(args['cap']), 'power_forecast'] = float(args['cap'])
         pre_data.loc[pre_data['power_forecast'] < 0, 'power_forecast'] = 0
 

+ 3 - 2
models_processing/model_tf/tf_lstm_zone_train.py

@@ -13,8 +13,9 @@ from data_processing.data_operation.custom_data_handler import CustomDataHandler
 import time, yaml, threading
 from copy import deepcopy
 from models_processing.model_tf.tf_lstm_zone import TSHandler
-from common.database_dml_koi import *
+from common.database_dml import *
 from common.logs import Log
+from common.data_utils import deep_update
 logger = Log('tf_ts').logger
 np.random.seed(42)  # NumPy随机种子
 app = Flask('tf_lstm_zone_train——service')
@@ -33,7 +34,7 @@ def update_config():
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
     request_args['time_series'] = request_args.get('time_series', 1)
     request_args['zone'] = request_args['zone'].split(',')
-    current_config.update(request_args)
+    current_config = deep_update(current_config, request_args)
 
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)

+ 143 - 0
models_processing/model_tf/tf_multi_nwp_pre.py

@@ -0,0 +1,143 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+# @FileName  :tf_lstm_pre.py
+# @Time      :2025/2/13 10:52
+# @Author    :David
+# @Company: shenyang JY
+import json, copy
+import numpy as np
+from flask import Flask, request, g
+import logging, argparse, traceback
+from common.database_dml import *
+from common.processing_data_common import missing_features, str_to_list
+from data_processing.data_operation.custom_data_handler import CustomDataHandler
+from models_processing.model_tf.tf_lstm_zone import TSHandler
+from threading import Lock
+import time, yaml
+from copy import deepcopy
+model_lock = Lock()
+from itertools import chain
+from common.logs import Log
+from common.data_utils import deep_update
+# logger = Log('tf_bp').logger()
+logger = Log('tf_ts').logger
+np.random.seed(42)  # NumPy随机种子
+# tf.set_random_seed(42)  # TensorFlow随机种子
+app = Flask('tf_lstm_zone_pre——service')
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+with open(os.path.join(current_dir, 'lstm.yaml'), 'r', encoding='utf-8') as f:
+    global_config = yaml.safe_load(f)  # 只读的全局配置
+
+@app.before_request
+def update_config():
+    # ------------ 整理参数,整合请求参数 ------------
+    # 深拷贝全局配置 + 合并请求参数
+    current_config = deepcopy(global_config)
+    request_args = request.values.to_dict()
+    # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
+    request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
+    request_args['time_series'] = request_args.get('time_series', 1)
+    request_args['zone'] = request_args['zone'].split(',')
+    current_config = deep_update(current_config, request_args)
+    # 存储到请求上下文
+    g.opt = argparse.Namespace(**current_config)
+    g.dh = CustomDataHandler(logger, current_config)  # 每个请求独立实例
+    g.ts = TSHandler(logger, current_config)
+
+@app.route('/tf_lstm_zone_predict', methods=['POST'])
+def model_prediction_lstm():
+    # 获取程序开始时间
+    start_time = time.time()
+    result = {}
+    success = 0
+    dh = g.dh
+    ts = g.ts
+    args = deepcopy(g.opt.__dict__)
+    logger.info("Program starts execution!")
+    try:
+        pre_data = get_data_from_mongo(args)
+        if args.get('algorithm_test', 0):
+            field_mapping = {'clearsky_ghi': 'clearskyGhi', 'dni_calcd': 'dniCalcd','surface_pressure': 'surfacePressure'}
+            pre_data = pre_data.rename(columns=field_mapping)
+        feature_scaler, target_scaler = get_scaler_model_from_mongo(args)
+        ts.opt.cap = round(target_scaler.transform(np.array([[float(args['cap'])]]))[0, 0], 2)
+        ts.get_model(args)
+        dh.opt.features = json.loads(ts.model_params)['Model']['features'].split(',')
+        scaled_pre_x, pre_data = dh.pre_data_handler(pre_data, feature_scaler, time_series=args['time_series'])
+        res = list(chain.from_iterable(target_scaler.inverse_transform(ts.predict(scaled_pre_x)[1])))
+        pre_data['farm_id'] = args.get('farm_id', 'null')
+        if int(args.get('algorithm_test', 0)):
+            pre_data[args['model_name']] = res[:len(pre_data)]
+            pre_data.rename(columns={args['col_time']: 'dateTime'}, inplace=True)
+            pre_data = pre_data[['dateTime', 'farm_id', args['target'], args['model_name'], 'dq']]
+            pre_data = pre_data.melt(id_vars=['dateTime', 'farm_id', args['target']], var_name='model', value_name='power_forecast')
+            res_cols = ['dateTime', 'power_forecast', 'farm_id', args['target'], 'model']
+            if 'howLongAgo' in args:
+                pre_data['howLongAgo'] = int(args['howLongAgo'])
+                res_cols += ['howLongAgo']
+        else:
+            pre_data['power_forecast'] = res[:len(pre_data)]
+            pre_data.rename(columns={args['col_time']: 'date_time'}, inplace=True)
+            res_cols = ['date_time', 'power_forecast', 'farm_id']
+        pre_data = pre_data[res_cols]
+
+        pre_data.loc[:, 'power_forecast'] = pre_data.loc[:, 'power_forecast'].apply(lambda x: float(f"{x:.2f}"))
+        pre_data.loc[pre_data['power_forecast'] > float(args['cap']), 'power_forecast'] = float(args['cap'])
+        pre_data.loc[pre_data['power_forecast'] < 0, 'power_forecast'] = 0
+
+        insert_data_into_mongo(pre_data, args)
+        success = 1
+    except Exception as e:
+        my_exception = traceback.format_exc()
+        my_exception.replace("\n", "\t")
+        result['msg'] = my_exception
+    end_time = time.time()
+
+    result['success'] = success
+    result['args'] = args
+    result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
+    result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
+    print("Program execution ends!")
+    return result
+
+
+if __name__ == "__main__":
+    print("Program starts execution!")
+    from waitress import serve
+    serve(app, host="0.0.0.0", port=10125,
+          threads=8,  # 指定线程数(默认4,根据硬件调整)
+          channel_timeout=600  # 连接超时时间(秒)
+          )
+    print("server start!")
+
+    # ------------------------测试代码------------------------
+    # args_dict = {"mongodb_database": 'david_test', 'scaler_table': 'j00083_scaler', 'model_name': 'bp1.0.test',
+    #              'model_table': 'j00083_model', 'mongodb_read_table': 'j00083_test', 'col_time': 'date_time', 'mongodb_write_table': 'j00083_rs',
+    #              'features': 'speed10,direction10,speed30,direction30,speed50,direction50,speed70,direction70,speed90,direction90,speed110,direction110,speed150,direction150,speed170,direction170'}
+    # args_dict['features'] = args_dict['features'].split(',')
+    # arguments.update(args_dict)
+    # dh = DataHandler(logger, arguments)
+    # ts = TSHandler(logger)
+    # opt = argparse.Namespace(**arguments)
+    #
+    # opt.Model['input_size'] = len(opt.features)
+    # pre_data = get_data_from_mongo(args_dict)
+    # feature_scaler, target_scaler = get_scaler_model_from_mongo(arguments)
+    # pre_x = dh.pre_data_handler(pre_data, feature_scaler, opt)
+    # ts.get_model(arguments)
+    # result = ts.predict(pre_x)
+    # result1 = list(chain.from_iterable(target_scaler.inverse_transform([result.flatten()])))
+    # pre_data['power_forecast'] = result1[:len(pre_data)]
+    # pre_data['farm_id'] = 'J00083'
+    # pre_data['cdq'] = 1
+    # pre_data['dq'] = 1
+    # pre_data['zq'] = 1
+    # pre_data.rename(columns={arguments['col_time']: 'date_time'}, inplace=True)
+    # pre_data = pre_data[['date_time', 'power_forecast', 'farm_id', 'cdq', 'dq', 'zq']]
+    #
+    # pre_data['power_forecast'] = pre_data['power_forecast'].round(2)
+    # pre_data.loc[pre_data['power_forecast'] > opt.cap, 'power_forecast'] = opt.cap
+    # pre_data.loc[pre_data['power_forecast'] < 0, 'power_forecast'] = 0
+    #
+    # insert_data_into_mongo(pre_data, arguments)

+ 123 - 0
models_processing/model_tf/tf_multi_nwp_train.py

@@ -0,0 +1,123 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+# @FileName  :tf_lstm_train.py
+# @Time      :2025/2/13 10:52
+# @Author    :David
+# @Company: shenyang JY
+import json, copy
+import numpy as np
+from flask import Flask, request, jsonify, g
+import traceback, uuid
+import logging, argparse
+from data_processing.data_operation.custom_data_handler import CustomDataHandler
+import time, yaml, threading
+from copy import deepcopy
+from models_processing.model_tf.tf_lstm_zone import TSHandler
+from common.database_dml import *
+from common.logs import Log
+from common.data_utils import deep_update
+logger = Log('tf_ts').logger
+np.random.seed(42)  # NumPy随机种子
+app = Flask('tf_lstm_zone_train——service')
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+with open(os.path.join(current_dir, 'lstm.yaml'), 'r', encoding='utf-8') as f:
+    global_config = yaml.safe_load(f)  # 只读的全局配置
+
+@app.before_request
+def update_config():
+    # ------------ 整理参数,整合请求参数 ------------
+    # 深拷贝全局配置 + 合并请求参数
+    current_config = deepcopy(global_config)
+    request_args = request.values.to_dict()
+    # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
+    request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
+    request_args['time_series'] = request_args.get('time_series', 1)
+    current_config = deep_update(current_config, request_args)
+
+    # 存储到请求上下文
+    g.opt = argparse.Namespace(**current_config)
+    g.dh = CustomDataHandler(logger, current_config)  # 每个请求独立实例
+    g.ts = TSHandler(logger, current_config)
+
+
+@app.route('/tf_lstm_zone_training', methods=['POST'])
+def model_training_lstm():
+    # 获取程序开始时间
+    start_time = time.time()
+    result = {}
+    success = 0
+    dh = g.dh
+    ts = g.ts
+    args = deepcopy(g.opt.__dict__)
+    logger.info("Program starts execution!")
+    try:
+        # ------------ 获取数据,预处理训练数据 ------------
+        train_data = get_data_from_mongo(args)
+        train_data_1 = get_data_from_mongo(args)
+        train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap = dh.train_data_handler(train_data, time_series=args['time_series'])
+        ts.opt.cap = round(scaled_cap, 2)
+        ts.opt.Model['input_size'] = len(dh.opt.features)
+        # ------------ 训练模型,保存模型 ------------
+        # 1. 如果是加强训练模式,先加载预训练模型特征参数,再预处理训练数据
+        # 2. 如果是普通模式,先预处理训练数据,再根据训练数据特征加载模型
+        model = ts.train_init() if ts.opt.Model['add_train'] else ts.get_keras_model(ts.opt, time_series=args['time_series'], lstm_type=1)
+        if ts.opt.Model['add_train']:
+            if model:
+                feas = json.loads(ts.model_params)['features']
+                if set(feas).issubset(set(dh.opt.features)):
+                    dh.opt.features = list(feas)
+                    train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap = dh.train_data_handler(train_data, time_series=args['time_series'])
+                else:
+                    model = ts.get_keras_model(ts.opt, time_series=args['time_series'], lstm_type=1)
+                    logger.info("训练数据特征,不满足,加强训练模型特征")
+            else:
+                model = ts.get_keras_model(ts.opt, time_series=args['time_series'], lstm_type=1)
+        ts_model = ts.training(model, [train_x, train_y, valid_x, valid_y])
+        args['Model']['features'] = ','.join(dh.opt.features)
+        args['params'] = json.dumps(args)
+        args['descr'] = '测试'
+        args['gen_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
+
+        insert_trained_model_into_mongo(ts_model, args)
+        insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, args)
+        success = 1
+    except Exception as e:
+        my_exception = traceback.format_exc()
+        my_exception.replace("\n", "\t")
+        result['msg'] = my_exception
+    end_time = time.time()
+    result['success'] = success
+    result['args'] = args
+    result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
+    result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
+    print("Program execution ends!")
+    return result
+
+
+if __name__ == "__main__":
+    print("Program starts execution!")
+    from waitress import serve
+    serve(app, host="0.0.0.0", port=10124,
+          threads=8,  # 指定线程数(默认4,根据硬件调整)
+          channel_timeout=600  # 连接超时时间(秒)
+          )
+    print("server start!")
+    # args_dict = {"mongodb_database": 'realtimeDq', 'scaler_table': 'j00600_scaler', 'model_name': 'lstm1',
+    # 'model_table': 'j00600_model', 'mongodb_read_table': 'j00600', 'col_time': 'dateTime',
+    # 'features': 'speed10,direction10,speed30,direction30,speed50,direction50,speed70,direction70,speed90,direction90,speed110,direction110,speed150,direction150,speed170,direction170'}
+    # args_dict['features'] = args_dict['features'].split(',')
+    # args.update(args_dict)
+    # dh = DataHandler(logger, args)
+    # ts = TSHandler(logger, args)
+    # opt = argparse.Namespace(**args)
+    # opt.Model['input_size'] = len(opt.features)
+    # train_data = get_data_from_mongo(args_dict)
+    # train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes = dh.train_data_handler(train_data)
+    # ts_model = ts.training([train_x, train_y, valid_x, valid_y])
+    #
+    # args_dict['gen_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
+    # args_dict['params'] = args
+    # args_dict['descr'] = '测试'
+    # insert_trained_model_into_mongo(ts_model, args_dict)
+    # insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, args_dict)

+ 1 - 1
models_processing/model_tf/tf_tcn.py

@@ -11,7 +11,7 @@ from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoa
 from tensorflow.keras import optimizers, regularizers
 from models_processing.model_tf.losses import region_loss
 import numpy as np
-from common.database_dml_koi import *
+from common.database_dml import *
 from models_processing.model_tf.settings import set_deterministic
 from threading import Lock
 import argparse

+ 1 - 1
models_processing/model_tf/tf_test.py

@@ -12,7 +12,7 @@ from tensorflow.keras import optimizers, regularizers
 from tensorflow.keras.layers import BatchNormalization, GlobalAveragePooling1D, Dropout, Add, Concatenate, Multiply
 from models_processing.model_tf.losses import region_loss
 import numpy as np
-from common.database_dml_koi import *
+from common.database_dml import *
 from models_processing.model_tf.settings import set_deterministic
 from threading import Lock
 import argparse

+ 4 - 3
models_processing/model_tf/tf_test_pre.py

@@ -8,7 +8,7 @@ import json, copy
 import numpy as np
 from flask import Flask, request, g
 import logging, argparse, traceback
-from common.database_dml_koi import *
+from common.database_dml import *
 from common.processing_data_common import missing_features, str_to_list
 from data_processing.data_operation.data_handler import DataHandler
 from threading import Lock
@@ -17,6 +17,7 @@ from copy import deepcopy
 model_lock = Lock()
 from itertools import chain
 from common.logs import Log
+from common.data_utils import deep_update
 from tf_test import TSHandler
 # logger = Log('tf_bp').logger()
 logger = Log('tf_test').logger
@@ -36,7 +37,7 @@ def update_config():
     request_args = request.values.to_dict()
     # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
-    current_config.update(request_args)
+    current_config = deep_update(current_config, request_args)
 
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)
@@ -81,7 +82,7 @@ def model_prediction_test():
             pre_data.rename(columns={args['col_time']: 'date_time'}, inplace=True)
             res_cols = ['date_time', 'power_forecast', 'farm_id']
         pre_data = pre_data[res_cols]
-        pre_data.loc[:, 'power_forecast'] = pre_data['power_forecast'].round(2)
+        pre_data.loc[:, 'power_forecast'] = pre_data.loc[:, 'power_forecast'].apply(lambda x: float(f"{x:.2f}"))
         pre_data.loc[pre_data['power_forecast'] > float(args['cap']), 'power_forecast'] = float(args['cap'])
         pre_data.loc[pre_data['power_forecast'] < 0, 'power_forecast'] = 0
         insert_data_into_mongo(pre_data, args)

+ 3 - 2
models_processing/model_tf/tf_test_train.py

@@ -13,8 +13,9 @@ from data_processing.data_operation.data_handler import DataHandler
 import time, yaml, threading
 from copy import deepcopy
 from models_processing.model_tf.tf_test import TSHandler
-from common.database_dml_koi import *
+from common.database_dml import *
 from common.logs import Log
+from common.data_utils import deep_update
 logger = Log('tf_test').logger
 np.random.seed(42)  # NumPy随机种子
 app = Flask('tf_test_train——service')
@@ -31,7 +32,7 @@ def update_config():
     request_args = request.values.to_dict()
     # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
-    current_config.update(request_args)
+    current_config = deep_update(current_config, request_args)
 
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)

+ 1 - 1
models_processing/model_tf/tf_transformer.py

@@ -11,7 +11,7 @@ from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoa
 from tensorflow.keras import optimizers, regularizers
 from models_processing.model_tf.losses import region_loss
 import numpy as np
-from common.database_dml_koi import *
+from common.database_dml import *
 from models_processing.model_tf.settings import set_deterministic
 from threading import Lock
 import argparse

+ 156 - 75
post_processing/cdq_coe_gen.py

@@ -7,14 +7,13 @@
 import os, requests, json, time, traceback
 import pandas as pd
 import numpy as np
+from bayes_opt import BayesianOptimization
 from common.database_dml_koi import get_data_from_mongo
-from pymongo import MongoClient
-from flask import Flask,request,jsonify, g
+from flask import Flask, request, g
 from datetime import datetime
-# from common.logs import Log
-# logger = Log('post-processing').logger
-from logging import getLogger
-logger = getLogger('xx')
+from common.logs import Log
+
+logger = Log('post-processing').logger
 current_path = os.path.dirname(__file__)
 API_URL = "http://ds2:18080/accuracyAndBiasByJSON"
 app = Flask('cdq_coe_gen——service')
@@ -23,25 +22,29 @@ app = Flask('cdq_coe_gen——service')
 @app.before_request
 def update_config():
     # ------------ 整理参数,整合请求参数 ------------
-    g.coe = {}
+    g.coe = {'T'+str(x):{} for x in range(1, 17)}
 
 
-def iterate_coe(pre_data, point, col_power, col_pre, coe):
+def iterate_coe_simple(pre_data, point, config, coe):
     """
     更新16个点系数
     """
     T = 'T' + str(point + 1)
+    col_pre = config['col_pre']
     best_acc, best_score1, best_coe_m, best_coe_n = 0, 0, 0, 0
-    best_score, best_acc1, best_score_m, best_score_n = 999, 0, 0, 0
-    req_his_fix = prepare_request_body(pre_data, col_power, 'his_fix')
-    req_dq = prepare_request_body(pre_data, col_power, col_pre)
+    best_score, best_acc1, best_score_m, best_score_n = 999, 0, 999, 0
+
+    pre_data = history_error(pre_data, config['col_power'], config['col_pre'], int(coe[T]['hour']//0.25))
+    pre_data = curve_limited(pre_data, config, 'his_fix')
+    req_his_fix = prepare_request_body(pre_data, config, 'his_fix')
+    req_dq = prepare_request_body(pre_data, config, col_pre)
 
     his_fix_acc, his_fix_score = calculate_acc(API_URL, req_his_fix)
     dq_acc, dq_score = calculate_acc(API_URL, req_dq)
     for i in range(5, 210):
         for j in range(5, 210):
             pre_data["new"] = round(i / 170 * pre_data[col_pre] + j / 170 * pre_data['his_fix'], 3)
-            req_new = prepare_request_body(pre_data, col_power, 'new')
+            req_new = prepare_request_body(pre_data, config, 'new')
             acc, acc_score = calculate_acc(API_URL, req_new)
 
             if acc > best_acc:
@@ -57,18 +60,10 @@ def iterate_coe(pre_data, point, col_power, col_pre, coe):
 
     pre_data["coe-acc"] = round(best_coe_m * pre_data[col_pre] + best_coe_n * pre_data['his_fix'], 3)
     pre_data["coe-ass"] = round(best_score_m * pre_data[col_pre] + best_score_n * pre_data['his_fix'], 3)
-    logger.info(
-        "1.过去{} - {}的短期的准确率:{:.4f},自动确认系数后,{} 超短期的准确率:{:.4f},历史功率:{:.4f}".format(
-            pre_data['C_TIME'][0], pre_data['C_TIME'].iloc[-1], dq_acc, T, best_acc, his_fix_acc))
-    logger.info(
-        "2.过去{} - {}的短期的考核分:{:.4f},自动确认系数后,{} 超短期的考核分:{:.4f},历史功率:{:.4f}".format(
-            pre_data['C_TIME'][0], pre_data['C_TIME'].iloc[-1], dq_score, T, best_score1, his_fix_score))
-    logger.info(
-        "3.过去{} - {}的短期的准确率:{:.4f},自动确认系数后,{} 超短期的准确率:{:.4f},历史功率:{:.4f}".format(
-            pre_data['C_TIME'][0], pre_data['C_TIME'].iloc[-1], dq_acc, T, best_acc1, his_fix_acc))
-    logger.info(
-        "4.过去{} - {}的短期的考核分:{:.4f},自动确认系数后,{} 超短期的考核分:{:.4f},历史功率:{:.4f}".format(
-            pre_data['C_TIME'][0], pre_data['C_TIME'].iloc[-1], dq_score, T, best_score, his_fix_score))
+    logger.info("1.过去{} - {}的短期的准确率:{:.4f},自动确认系数后,{} 超短期的准确率:{:.4f},历史功率:{:.4f}".format(pre_data['C_TIME'][0], pre_data['C_TIME'].iloc[-1], dq_acc, T, best_acc, his_fix_acc))
+    logger.info("2.过去{} - {}的短期的考核分:{:.4f},自动确认系数后,{} 超短期的考核分:{:.4f},历史功率:{:.4f}".format(pre_data['C_TIME'][0], pre_data['C_TIME'].iloc[-1], dq_score, T, best_score1, his_fix_score))
+    logger.info("3.过去{} - {}的短期的准确率:{:.4f},自动确认系数后,{} 超短期的准确率:{:.4f},历史功率:{:.4f}".format(pre_data['C_TIME'][0], pre_data['C_TIME'].iloc[-1], dq_acc, T, best_acc1, his_fix_acc))
+    logger.info("4.过去{} - {}的短期的考核分:{:.4f},自动确认系数后,{} 超短期的考核分:{:.4f},历史功率:{:.4f}".format(pre_data['C_TIME'][0], pre_data['C_TIME'].iloc[-1], dq_score, T, best_score, his_fix_score))
 
     coe[T]['score_m'] = round(best_score_m, 3)
     coe[T]['score_n'] = round(best_score_n, 3)
@@ -76,38 +71,124 @@ def iterate_coe(pre_data, point, col_power, col_pre, coe):
     coe[T]['acc_n'] = round(best_coe_n, 3)
     logger.info("系数轮询后,最终调整的系数为:{}".format(coe))
 
-def prepare_request_body(df, col_power, col_pre):
+
+def iterate_coe(pre_data, point, config, coe):
+    """使用贝叶斯优化进行系数寻优"""
+    T = 'T' + str(point + 1)
+    col_pre = config['col_pre']
+    col_time = config['col_time']
+
+    # 历史数据处理(保持原逻辑)
+    pre_data = history_error(pre_data, config['col_power'], config['col_pre'], int(coe[T]['hour'] // 0.25))
+    pre_data = curve_limited(pre_data, config, 'his_fix')
+    req_his_fix = prepare_request_body(pre_data, config, 'his_fix')
+    req_dq = prepare_request_body(pre_data, config, col_pre)
+
+    # 获取基准值(保持原逻辑)
+    his_fix_acc, his_fix_score = calculate_acc(API_URL, req_his_fix)
+    dq_acc, dq_score = calculate_acc(API_URL, req_dq)
+
+    # 定义贝叶斯优化目标函数
+    def evaluate_coefficients(m, n):
+        """评估函数返回准确率和考核分的元组"""
+        local_data = pre_data.copy()
+        local_data["new"] = round(m * local_data[col_pre] + n * local_data['his_fix'], 3)
+        local_data = curve_limited(local_data, config, 'new')
+        req_new = prepare_request_body(local_data, config, 'new')
+        acc, score = calculate_acc(API_URL, req_new)
+        return acc, score
+
+    # 优化准确率
+    def acc_optimizer(m, n):
+        acc, _ = evaluate_coefficients(m, n)
+        return acc
+
+    # 优化考核分
+    def score_optimizer(m, n):
+        _, score = evaluate_coefficients(m, n)
+        return -score  # 取负数因为要最大化负分即最小化原分数
+
+    # 参数空间(保持原参数范围)
+    pbounds = {
+        'm': (5 / 170, 210 / 170),  # 原始范围映射到[0.0294, 1.235]
+        'n': (5 / 170, 210 / 170)
+    }
+
+    # 执行准确率优化
+    acc_bo = BayesianOptimization(f=acc_optimizer, pbounds=pbounds, random_state=42)
+    acc_bo.maximize(init_points=70, n_iter=400) # 增大初始点和迭代次数,捕捉可能的多峰结构
+    best_acc_params = acc_bo.max['params']
+    best_coe_m, best_coe_n = best_acc_params['m'], best_acc_params['n']
+    best_acc = acc_bo.max['target']
+
+    # 执行考核分优化
+    # score_bo = BayesianOptimization(f=score_optimizer, pbounds=pbounds, random_state=42)
+    # score_bo.maximize(init_points=10, n_iter=20)
+    # best_score_params = score_bo.max['params']
+    # best_score_m, best_score_n = best_score_params['m'], best_score_params['n']
+    # best_score = -score_bo.max['target']  # 恢复原始分数
+
+    # 应用最优系数(保持原处理逻辑)
+    pre_data["coe-acc"] = round(best_coe_m * pre_data[col_pre] + best_coe_n * pre_data['his_fix'], 3)
+    # pre_data["coe-ass"] = round(best_score_m * pre_data[col_pre] + best_score_n * pre_data['his_fix'], 3)
+
+    # 记录日志(保持原格式)
+    logger.info("过去{} - {}的短期的准确率:{:.4f},历史功率:{:.4f},自动确认系数后,{} 超短期的准确率:{:.4f}".format(pre_data[col_time][0], pre_data[col_time].iloc[-1], dq_acc, his_fix_acc, T, best_acc))
+
+    # 更新系数表(保持原逻辑)
+    coe[T].update({
+        # 'score_m': round(best_score_m, 3),
+        # 'score_n': round(best_score_n, 3),
+        'acc_m': round(best_coe_m, 3),
+        'acc_n': round(best_coe_n, 3)
+    })
+    logger.info("贝叶斯优化后,最终调整的系数为:{}".format(coe))
+
+def iterate_his_coe(pre_data, point, config, coe):
+    """
+    更新临近时长Δ
+    """
+    T = 'T' + str(point + 1)
+    best_acc, best_hour = 0, 1
+    for hour in np.arange(0.25, 4.25, 0.25):
+        data = pre_data.copy()
+        his_window = int(hour // 0.25)
+        pre_data_f = history_error(data, config['col_power'], config['col_pre'], his_window)
+        pre_data_f = curve_limited(pre_data_f, config, 'his_fix')
+        req_his_fix = prepare_request_body(pre_data_f, config, 'his_fix')
+        his_fix_acc, his_fix_score = calculate_acc(API_URL, req_his_fix)
+
+        if his_fix_acc > best_acc:
+            best_acc = his_fix_acc
+            best_hour = float(round(hour, 2))
+    coe[T]['hour'] = best_hour
+    logger.info(f"{T} 点的最优临近时长:{best_hour}")
+
+def prepare_request_body(df, config, predict):
     """
     准备请求体,动态保留MongoDB中的所有字段
     """
     data = df.copy()
     # 转换时间格式为字符串
-    if 'dateTime' in data.columns and isinstance(data['dateTime'].iloc[0], datetime):
-        data['dateTime'] = data['dateTime'].dt.strftime('%Y-%m-%d %H:%M:%S')
-    data['model'] = col_pre
-    # 排除不需要的字段(如果有)
-    exclude_fields = ['_id']  # 通常排除MongoDB的默认_id字段
-
-    # 获取所有字段名(排除不需要的字段)
-    available_fields = [col for col in data.columns if col not in exclude_fields]
-
-    # 转换为记录列表(保留所有字段)
-    data = data[available_fields].to_dict('records')
-
+    if config['col_time'] in data.columns and isinstance(data[config['col_time']].iloc[0], datetime):
+        data[config['col_time'] ] = data[config['col_time'] ].dt.strftime('%Y-%m-%d %H:%M:%S')
+    data['model'] = predict
+    # 保留必要的字段
+    data = data[[config['col_time'], config['col_power'], predict, 'model']].to_dict('records')
     # 构造请求体(固定部分+动态数据部分)
     request_body = {
-        "stationCode": "J00600",
-        "realPowerColumn": col_power,
-        "ablePowerColumn": col_power,
-        "predictPowerColumn": col_pre,
-        "inStalledCapacityName": 153,
+        "stationCode": config['stationCode'],
+        "realPowerColumn": config['col_power'],
+        "ablePowerColumn": config['col_power'],
+        "predictPowerColumn": predict,
+        "inStalledCapacityName": config['inStalledCapacityName'],
         "computTypeEnum": "E2",
-        "computMeasEnum": "E2",
-        "openCapacityName": 153,
-        "onGridEnergy": 0,
-        "price": 0,
-        "fault": -99,
-        "colTime": "dateTime",  #时间列名(可选,要与上面'dateTime一致')
+        "computMeasEnum": config.get('computMeasEnum', 'E2'),
+        "openCapacityName": config['openCapacityName'],
+        "onGridEnergy": config.get('onGridEnergy', 1),
+        "price": config.get('price', 1),
+        "fault": config.get('fault', -99),
+        "colTime": config['col_time'],  #时间列名(可选,要与上面'dateTime一致')
         # "computPowersEnum": "E4"  # 计算功率类型(可选)
         "data": data  # MongoDB数据
     }
@@ -132,32 +213,35 @@ def calculate_acc(api_url, request_body):
         if response.status_code == 200:
             acc = np.average([res['accuracy'] for res in result])
             # ass = np.average([res['accuracyAssessment'] for res in result])
-            print("111111111")
             return acc, 0
         else:
-            logger.info(f"失败:{result['status']},{result['error']}")
-            print(f"失败:{result['status']},{result['error']}")
-            print("22222222")
+            logger.info(f"{response.status_code}失败:{result['status']},{result['error']}")
     except requests.exceptions.RequestException as e:
-        print(f"API调用失败: {e}")
-        print("333333333")
+        logger.info(f"准确率接口调用失败: {e}")
         return None
 
-def history_error(data, col_power, col_pre):
+def history_error(data, col_power, col_pre, his_window):
     data['error'] =  data[col_power] - data[col_pre]
     data['error'] = data['error'].round(2)
     data.reset_index(drop=True, inplace=True)
     # 用前面5个点的平均error,和象心力相加
-    numbers = len(data) - 5
-    datas = [data.iloc[x: x+5, :].reset_index(drop=True) for x in range(0, numbers)]
-    data_error = [np.mean(d.iloc[0:5, -1]) for d in datas]
-    pad_data_error = np.pad(data_error, (5, 0), mode='constant', constant_values=0)
+    numbers = len(data) - his_window
+    datas = [data.iloc[x: x+his_window, :].reset_index(drop=True) for x in range(0, numbers)]
+    data_error = [np.mean(d.iloc[0:his_window, -1]) for d in datas]
+    pad_data_error = np.pad(data_error, (his_window, 0), mode='constant', constant_values=0)
     data['his_fix'] = data[col_pre] + pad_data_error
-    data = data.iloc[5:, :].reset_index(drop=True)
-    data.loc[data[col_pre] <= 0, ['his_fix']] = 0
-    data['dateTime'] = pd.to_datetime(data['dateTime'])
-    data = data.loc[:, ['dateTime', col_power, col_pre, 'his_fix']]
-    # data.to_csv('J01080原始数据.csv', index=False)
+    data = data.iloc[his_window:, :].reset_index(drop=True)
+    return data
+
+def curve_limited(pre_data, config, predict):
+    """
+    plant_type: 0 风 1 光
+    """
+    data = pre_data.copy()
+    col_time, cap = config['col_time'], float(config['openCapacityName'])
+    data[col_time] = pd.to_datetime(data[col_time])
+    data.loc[data[predict] < 0, [predict]] = 0
+    data.loc[data[predict] > cap, [predict]] = cap
     return data
 
 @app.route('/cdq_coe_gen', methods=['POST'])
@@ -171,10 +255,10 @@ def get_station_cdq_coe():
     try:
         args = request.values.to_dict()
         logger.info(args)
-        data = get_data_from_mongo(args).sort_values(by='dateTime', ascending=True)
-        pre_data = history_error(data, col_power='realPower', col_pre='dq')
+        data = get_data_from_mongo(args).sort_values(by=args['col_time'], ascending=True)
         for point in range(0, 16, 1):
-            iterate_coe(pre_data, point, 'realPower', 'dq', coe)
+            iterate_his_coe(data, point, args, coe)
+            iterate_coe(data, point, args, coe)
         success = 1
     except Exception as e:
         my_exception = traceback.format_exc()
@@ -203,11 +287,8 @@ if __name__ == "__main__":
     # run_code = 0
     print("Program starts execution!")
     from waitress import serve
-
-    serve(
-        app,
-        host="0.0.0.0",
-        port=10123,
-        threads=8,  # 指定线程数(默认4,根据硬件调整)
-        channel_timeout=600  # 连接超时时间(秒)
-    )
+    serve(app, host="0.0.0.0", port=10123,
+          threads=8,  # 指定线程数(默认4,根据硬件调整)
+          channel_timeout=600  # 连接超时时间(秒)
+          )
+    print("server start!")

+ 4 - 1
requirements.txt

@@ -16,4 +16,7 @@ protobuf==3.20.3
 APScheduler==3.10.4
 paramiko==3.5.0
 PyYAML==6.0.1
-keras==3.8.0
+keras==3.8.0
+toml==0.10.2
+JayDeBeApi==1.2.3
+jpype1==1.5.2

+ 2 - 6
run_all.py

@@ -18,10 +18,6 @@ services = [
     ("models_processing/model_predict/model_prediction_lightgbm.py", 10090),
     ("models_processing/model_train/model_training_lstm.py", 10096),
     ("models_processing/model_predict/model_prediction_lstm.py", 10097),
-    ("models_processing/model_train/model_training_ml.py", 10126),
-    ("models_processing/model_predict/model_prediction_ml.py", 10127),
-
-
     ("models_processing/model_tf/tf_bp_pre.py", 10110),
     ("models_processing/model_tf/tf_bp_train.py", 10111),
     ("models_processing/model_tf/tf_cnn_pre.py", 10112),
@@ -37,7 +33,6 @@ services = [
     ("models_processing/model_tf/tf_lstm_zone_pre.py", 10125),
     ("models_processing/model_tf/tf_lstm_zone_train.py", 10124),
 
-
     ("post_processing/post_processing.py", 10098),
     ("evaluation_processing/analysis.py", 10099),
     ("models_processing/model_predict/res_prediction.py", 10105),
@@ -48,7 +43,8 @@ services = [
     ("data_processing/data_operation/data_tj_nwp_ftp.py", 10106),
     ("post_processing/pre_post_processing.py", 10107),
     ("post_processing/cdq_coe_gen.py", 10123),
-    ("post_processing/post_process.py", 10128),
+    ("models_processing/model_predict/model_prediction_photovoltaic_physical.py", 10126),
+    ("data_processing/data_operation/hive_to_mongo.py", 10127),
 ]
 
 # 获取当前脚本所在的根目录