David 2 months ago
Parent
Current commit
1b03c8413f

+ 347 - 128
common/database_dml_koi.py

@@ -1,11 +1,15 @@
-from pymongo import MongoClient, UpdateOne, DESCENDING
+import pymongo
+from pymongo import MongoClient, UpdateOne, DESCENDING, ASCENDING
+from pymongo.errors import PyMongoError
 import pandas as pd
 from sqlalchemy import create_engine
 import pickle
 from io import BytesIO
 import joblib
-import h5py
+import h5py, os, io
 import tensorflow as tf
+from typing import Dict, Any, Optional, Union, Tuple
+import tempfile
 
 def get_data_from_mongo(args):
     mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
@@ -141,133 +145,348 @@ def insert_pickle_model_into_mongo(model, args):
     print("model inserted successfully!")
 
 
-def insert_h5_model_into_mongo(model,feature_scaler_bytes,target_scaler_bytes ,args):
-    mongodb_connection,mongodb_database,scaler_table,model_table,model_name = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
-                                args['mongodb_database'],args['scaler_table'],args['model_table'],args['model_name'])
-    client = MongoClient(mongodb_connection)
-    db = client[mongodb_database]
-    if scaler_table in db.list_collection_names():
-        db[scaler_table].drop()
-        print(f"Collection '{scaler_table} already exist, deleted successfully!")
-    collection = db[scaler_table]  # 集合名称
-    # Save the scalers in MongoDB as binary data
-    collection.insert_one({
-        "feature_scaler": feature_scaler_bytes.read(),
-        "target_scaler": target_scaler_bytes.read()
-    })
-    print("scaler_model inserted successfully!")
-    if model_table in db.list_collection_names():
-        db[model_table].drop()
-        print(f"Collection '{model_table} already exist, deleted successfully!")
-    model_table = db[model_table]
-    # 创建 BytesIO 缓冲区
-    model_buffer = BytesIO()
-    # 将模型保存为 HDF5 格式到内存 (BytesIO)
-    model.save(model_buffer, save_format='h5')
-    # 将指针移到缓冲区的起始位置
-    model_buffer.seek(0)
-    # 获取模型的二进制数据
-    model_data = model_buffer.read()
-    # 将模型保存到 MongoDB
-    model_table.insert_one({
-        "model_name": model_name,
-        "model_data": model_data
-    })
-    print("模型成功保存到 MongoDB!")
-
-def insert_trained_model_into_mongo(model, args):
-    mongodb_connection,mongodb_database,model_table,model_name = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
-                                args['mongodb_database'],args['model_table'],args['model_name'])
-
-    gen_time, params_json, descr = args['gen_time'], args['params'], args['descr']
-    client = MongoClient(mongodb_connection)
-    db = client[mongodb_database]
-    if model_table in db.list_collection_names():
-        db[model_table].drop()
-        print(f"Collection '{model_table} already exist, deleted successfully!")
-    model_table = db[model_table]
-    # 创建 BytesIO 缓冲区
-    model_buffer = BytesIO()
-    # 将模型保存为 HDF5 格式到内存 (BytesIO)
-    model.save(model_buffer, save_format='h5')
-    # 将指针移到缓冲区的起始位置
-    model_buffer.seek(0)
-    # 获取模型的二进制数据
-    model_data = model_buffer.read()
-    # 将模型保存到 MongoDB
-    model_table.insert_one({
-        "model_name": model_name,
-        "model_data": model_data,
-        "gen_time": gen_time,
-        "params": params_json,
-        "descr": descr
-    })
-    print("模型成功保存到 MongoDB!")
-
-def insert_scaler_model_into_mongo(feature_scaler_bytes, scaled_target_bytes, args):
-    mongodb_connection,mongodb_database,scaler_table,model_table,model_name = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
-                                args['mongodb_database'],args['scaler_table'],args['model_table'],args['model_name'])
-    gen_time = args['gen_time']
-    client = MongoClient(mongodb_connection)
-    db = client[mongodb_database]
-    if scaler_table in db.list_collection_names():
-        db[scaler_table].drop()
-        print(f"Collection '{scaler_table} already exist, deleted successfully!")
-    collection = db[scaler_table]  # 集合名称
-    # Save the scalers in MongoDB as binary data
-    collection.insert_one({
-        "model_name": model_name,
-        "gent_time": gen_time,
-        "feature_scaler": feature_scaler_bytes.read(),
-        "target_scaler": scaled_target_bytes.read()
-    })
-    print("scaler_model inserted successfully!")
-
-
-def get_h5_model_from_mongo(args, custom=None):
-    mongodb_connection,mongodb_database,model_table,model_name = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['model_table'],args['model_name']
-    client = MongoClient(mongodb_connection)
-    # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
-    db = client[mongodb_database]
-    collection = db[model_table]  # 集合名称
-
-     # 查询 MongoDB 获取模型数据
-    model_doc = collection.find_one({"model_name": model_name}, sort=[('gen_time', DESCENDING)])
-    if model_doc:
-        model_data = model_doc['model_data']  # 获取模型的二进制数据
-        # 将二进制数据加载到 BytesIO 缓冲区
-        model_buffer = BytesIO(model_data)
-        # 从缓冲区加载模型
-         # 使用 h5py 和 BytesIO 从内存中加载模型
-        with h5py.File(model_buffer, 'r') as f:
-            model = tf.keras.models.load_model(f, custom_objects=custom)
-        print(f"{model_name}模型成功从 MongoDB 加载!")
-        client.close()
-        return model
-    else:
-        print(f"未找到model_name为 {model_name} 的模型。")
-        client.close()
-        return None
+def insert_trained_model_into_mongo(model: tf.keras.Model, args: Dict[str, Any]) -> str:
+    """
+    Insert a trained H5 model into MongoDB, keeping the collection capped at 50 models.
+    Parameters:
+    model : tf.keras.Model - the trained Keras model
+    args : dict - a dict containing the following keys:
+        mongodb_database: database name
+        model_table: collection name
+        model_name: model name
+        gen_time: model generation time (datetime object)
+        params: model parameters (a JSON-serializable object)
+        descr: model description text
+    """
+    # ------------------------- Parameter validation -------------------------
+    required_keys = {'mongodb_database', 'model_table', 'model_name',
+                     'gen_time', 'params', 'descr'}
+    if missing := required_keys - args.keys():
+        raise ValueError(f"Missing required parameters: {missing}")
+
+    # ------------------------- Configuration decoupling -------------------------
+    # Read the connection string from an environment variable (safer than hard-coding)
+    mongodb_connection = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
+
+    # ------------------------- Resource initialization -------------------------
+    fd, temp_path = None, None
+    client = None
+
+    try:
+        # ------------------------- Temporary file handling -------------------------
+        fd, temp_path = tempfile.mkstemp(suffix='.h5')
+        os.close(fd)  # release the file handle immediately
+
+        # ------------------------- Model saving -------------------------
+        try:
+            model.save(temp_path, save_format='h5')
+        except Exception as e:
+            raise RuntimeError(f"Failed to save model: {str(e)}") from e
+
+        # ------------------------- Database connection -------------------------
+        client = MongoClient(mongodb_connection)
+        db = client[args['mongodb_database']]
+        collection = db[args['model_table']]
+
+        # ------------------------- Index check -------------------------
+        if "gen_time_1" not in collection.index_information():
+            collection.create_index([("gen_time", ASCENDING)], name="gen_time_1")
+            print("Created gen_time index")
 
+        # ------------------------- Capacity control -------------------------
+        # Use the cheaper estimated document count
+        if collection.estimated_document_count() >= 50:
+            # Atomically evict the oldest model (empty filter {} matches any document)
+            if deleted := collection.find_one_and_delete(
+                    {},
+                    sort=[("gen_time", ASCENDING)],
+                    projection={"_id": 1, "model_name": 1, "gen_time": 1}
+            ):
+                print(f"Evicted model [{deleted['model_name']}] generated at: {deleted['gen_time']}")
 
-def get_scaler_model_from_mongo(args, only_feature_scaler=False):
+        # ------------------------- Data insertion -------------------------
+        with open(temp_path, 'rb') as f:
+            result = collection.insert_one({
+                "model_name": args['model_name'],
+                "model_data": f.read(),
+                "gen_time": args['gen_time'],
+                "params": args['params'],
+                "descr": args['descr']
+            })
+
+        print(f"✅ Model {args['model_name']} saved successfully | document ID: {result.inserted_id}")
+        return str(result.inserted_id)
+
+    except Exception as e:
+        # ------------------------- Exception classification -------------------------
+        error_type = "database operation" if isinstance(e, (pymongo.errors.PyMongoError, RuntimeError)) else "system error"
+        print(f"❌ {error_type} error - details: {str(e)}")
+        raise  # re-raise; adjust to business requirements if swallowing is preferred
+
+    finally:
+        # ------------------------- Resource cleanup -------------------------
+        if client:
+            client.close()
+        if temp_path and os.path.exists(temp_path):
+            try:
+                os.remove(temp_path)
+            except PermissionError:
+                print(f"⚠️ Failed to remove temporary file: {temp_path}")
+
+
+def insert_scaler_model_into_mongo(feature_scaler_bytes: BytesIO, target_scaler_bytes: BytesIO, args: Dict[str, Any]) -> str:
     """
-    根据模 型名称版本 和 生成时间 获取模型
+    Store the feature and target scalers in MongoDB, keeping the collection capped at 50 documents.
+
+    Parameters:
+    feature_scaler_bytes: BytesIO - serialized feature scaler byte stream
+    target_scaler_bytes: BytesIO - serialized target scaler byte stream
+    args : dict - a dict containing the following keys:
+        mongodb_database: database name
+        scaler_table: collection name
+        model_name: name of the associated model
+        gen_time: generation time (datetime object)
     """
-    mongodb_connection, mongodb_database, scaler_table, = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/", args['mongodb_database'], args['scaler_table'])
-    model_name, gen_time = args['model_name'], args['gent_time']
-    client = MongoClient(mongodb_connection)
-    # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
-    db = client[mongodb_database]
-    collection = db[scaler_table]  # 集合名称
-    # Retrieve the scalers from MongoDB
-    scaler_doc = collection.find_one({"model_name": model_name, "gen_time": gen_time})
-    # Deserialize the scalers
-
-    feature_scaler_bytes = BytesIO(scaler_doc["feature_scaler"])
-    feature_scaler = joblib.load(feature_scaler_bytes)
-    if only_feature_scaler:
-        return feature_scaler
-    target_scaler_bytes = BytesIO(scaler_doc["target_scaler"])
-    target_scaler = joblib.load(target_scaler_bytes)
-    return feature_scaler,target_scaler
+    # ------------------------- Parameter validation -------------------------
+    required_keys = {'mongodb_database', 'scaler_table', 'model_name', 'gen_time'}
+    if missing := required_keys - args.keys():
+        raise ValueError(f"Missing required parameters: {missing}")
+
+    # ------------------------- Configuration decoupling -------------------------
+    # Read the connection string from an environment variable (keeps credentials out of code)
+    mongodb_conn = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
+
+    # ------------------------- Input validation -------------------------
+    for buf, name in [(feature_scaler_bytes, "feature scaler"),
+                      (target_scaler_bytes, "target scaler")]:
+        if not isinstance(buf, BytesIO):
+            raise TypeError(f"{name} must be a BytesIO object")
+        if buf.getbuffer().nbytes == 0:
+            raise ValueError(f"{name} byte stream is empty")
+
+    client = None
+    try:
+        # ------------------------- Database connection -------------------------
+        client = MongoClient(mongodb_conn)
+        db = client[args['mongodb_database']]
+        collection = db[args['scaler_table']]
+
+        # ------------------------- Index maintenance -------------------------
+        if "gen_time_1" not in collection.index_information():
+            collection.create_index([("gen_time", ASCENDING)], name="gen_time_1")
+            print("⏱️ Created gen_time sort index")
+
+        # ------------------------- Capacity control -------------------------
+        # Use the estimated count for performance (an error of a few dozen documents is acceptable)
+        if collection.estimated_document_count() >= 50:
+            # Atomically evict the oldest scaler document (empty filter {} matches any document)
+            if deleted := collection.find_one_and_delete(
+                    {},
+                    sort=[("gen_time", ASCENDING)],
+                    projection={"_id": 1, "model_name": 1, "gen_time": 1}
+            ):
+                print(f"🗑️ Evicted oldest scaler [{deleted['model_name']}] generated at: {deleted['gen_time']}")
+
+        # ------------------------- Data insertion -------------------------
+        # Make sure the byte-stream pointers are at the start
+        feature_scaler_bytes.seek(0)
+        target_scaler_bytes.seek(0)
+
+        result = collection.insert_one({
+            "model_name": args['model_name'],
+            "gen_time": args['gen_time'],
+            "feature_scaler": feature_scaler_bytes.read(),
+            "target_scaler": target_scaler_bytes.read()
+        })
+
+        print(f"✅ Scalers for {args['model_name']} saved successfully | document ID: {result.inserted_id}")
+        return str(result.inserted_id)
+
+    except Exception as e:
+        # ------------------------- Exception classification -------------------------
+        error_type = "database operation" if isinstance(e, (pymongo.errors.PyMongoError, ValueError)) else "system error"
+        print(f"❌ {error_type} error - details: {str(e)}")
+        raise  # re-raise; adjust to business requirements if swallowing is preferred
+
+    finally:
+        # ------------------------- Resource cleanup -------------------------
+        if client:
+            client.close()
+        # Reset the byte-stream pointers so the buffers can be reused afterwards
+        feature_scaler_bytes.seek(0)
+        target_scaler_bytes.seek(0)
+
+
+def get_h5_model_from_mongo(args: Dict[str, Any], custom_objects: Optional[Dict[str, Any]] = None) -> Optional[tf.keras.Model]:
+    """
+    Fetch the latest version of the named model from MongoDB.
+
+    Parameters:
+    args : dict - a dict containing the following keys:
+        mongodb_database: database name
+        model_table: collection name
+        model_name: name of the model to fetch
+    custom_objects: dict - custom Keras objects required to deserialize the model
+
+    Returns:
+    tf.keras.Model - the loaded Keras model, or None if no matching document exists
+    """
+    # ------------------------- Parameter validation -------------------------
+    required_keys = {'mongodb_database', 'model_table', 'model_name'}
+    if missing := required_keys - args.keys():
+        raise ValueError(f"❌ Missing required parameters: {missing}")
+
+    # ------------------------- Environment configuration -------------------------
+    mongo_uri = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
+    client = None
+    try:
+        # ------------------------- Database connection -------------------------
+        client = MongoClient(
+            mongo_uri,
+            maxPoolSize=10,  # connection-pool tuning
+            socketTimeoutMS=5000
+        )
+        db = client[args['mongodb_database']]
+        collection = db[args['model_table']]
+
+        # ------------------------- Index maintenance -------------------------
+        index_name = "model_gen_time_idx"
+        if index_name not in collection.index_information():
+            collection.create_index(
+                [("model_name", 1), ("gen_time", DESCENDING)],
+                name=index_name
+            )
+            print("⏱️ Created compound index")
+
+        # ------------------------- Efficient query -------------------------
+        model_doc = collection.find_one(
+            {"model_name": args['model_name']},
+            sort=[('gen_time', DESCENDING)],
+            projection={"model_data": 1, "gen_time": 1}  # fetch only the required fields
+        )
+
+        if not model_doc:
+            print(f"⚠️ No record found for model '{args['model_name']}'")
+            return None
+
+        # ------------------------- In-memory loading -------------------------
+        model_data = model_doc['model_data']  # binary model payload
+        # Wrap the binary data in a BytesIO buffer and load the model
+        # directly from memory via h5py
+        model_buffer = BytesIO(model_data)
+        with h5py.File(model_buffer, 'r') as f:
+            model = tf.keras.models.load_model(f, custom_objects=custom_objects)
+        print(f"Model {args['model_name']} loaded from MongoDB successfully!")
+        return model
+    except tf.errors.NotFoundError as e:
+        print(f"❌ Model structure is missing required components: {str(e)}")
+        raise RuntimeError("Incomplete model architecture") from e
+
+    except Exception as e:
+        print(f"❌ System error: {str(e)}")
+        raise
+
+    finally:
+        # ------------------------- Resource cleanup -------------------------
+        if client:
+            client.close()
+
+
+def get_scaler_model_from_mongo(args: Dict[str, Any], only_feature_scaler: bool = False) -> Union[object, Tuple[object, object]]:
+    """
+    Load the latest preprocessing scalers for a model safely and efficiently.
+
+    Parameters:
+    args : dict - must contain the keys:
+        - mongodb_database: database name
+        - scaler_table: collection name
+        - model_name: target model name
+    only_feature_scaler : whether to return only the feature scaler
+
+    Returns:
+    a single scaler object, or a (feature_scaler, target_scaler) tuple
+
+    Raises:
+    ValueError : missing parameters or wrong types
+    RuntimeError : data access or deserialization failure
+    """
+    # ------------------------- Parameter validation -------------------------
+    required_keys = {'mongodb_database', 'scaler_table', 'model_name'}
+    if missing := required_keys - args.keys():
+        raise ValueError(f"❌ Missing required parameters: {missing}")
+
+    # ------------------------- Environment configuration -------------------------
+    mongo_uri = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
+
+    client = None
+    try:
+        # ------------------------- Database connection -------------------------
+        client = MongoClient(
+            mongo_uri,
+            maxPoolSize=20,  # connection-pool cap
+            socketTimeoutMS=3000,  # 3-second socket timeout
+            serverSelectionTimeoutMS=5000  # 5-second server-selection timeout
+        )
+        db = client[args['mongodb_database']]
+        collection = db[args['scaler_table']]
+
+        # ------------------------- Index maintenance -------------------------
+        index_name = "model_gen_time_idx"
+        if index_name not in collection.index_information():
+            collection.create_index(
+                [("model_name", 1), ("gen_time", DESCENDING)],
+                name=index_name,
+                background=True  # build in the background to avoid blocking (ignored by MongoDB 4.2+)
+            )
+            print("⏱️ Created compound index for scalers")
+
+        # ------------------------- Efficient query -------------------------
+        scaler_doc = collection.find_one(
+            {"model_name": args['model_name']},
+            sort=[('gen_time', DESCENDING)],
+            projection={"feature_scaler": 1, "target_scaler": 1, "gen_time": 1}
+        )
+
+        if not scaler_doc:
+            raise RuntimeError(f"⚠️ No scaler record found for model {args['model_name']}")
+
+        # ------------------------- Deserialization -------------------------
+        def load_scaler(data: bytes) -> object:
+            """Safely deserialize a stored scaler object."""
+            with BytesIO(data) as buffer:
+                buffer.seek(0)  # make sure the pointer is at the start
+                try:
+                    return joblib.load(buffer)
+                except Exception as e:  # joblib raises various errors for corrupt or incompatible payloads
+                    raise RuntimeError("Deserialization failed (possible version incompatibility)") from e
+
+        # Load the feature scaler
+        feature_data = scaler_doc["feature_scaler"]
+        if not isinstance(feature_data, bytes):
+            raise RuntimeError("Unexpected feature scaler data format")
+        feature_scaler = load_scaler(feature_data)
+
+        if only_feature_scaler:
+            return feature_scaler
+
+        # Load the target scaler
+        target_data = scaler_doc["target_scaler"]
+        if not isinstance(target_data, bytes):
+            raise RuntimeError("Unexpected target scaler data format")
+        target_scaler = load_scaler(target_data)
+
+        print(f"✅ Loaded scalers for {args['model_name']} (version time: {scaler_doc.get('gen_time', 'unknown')})")
+        return feature_scaler, target_scaler
+
+    except PyMongoError as e:
+        raise RuntimeError(f"🔌 Database operation failed: {str(e)}") from e
+    except RuntimeError:
+        raise  # propagate already-wrapped exceptions unchanged
+    except Exception as e:
+        raise RuntimeError(f"❌ Unexpected system error: {str(e)}") from e
+    finally:
+        # ------------------------- Resource cleanup -------------------------
+        if client:
+            client.close()
+
+

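Note: a minimal usage sketch for the retrieval side of the new helpers. The database, collection, and model names below are placeholders, and custom_objects={'rmse': rmse} assumes the stored model was compiled with the rmse loss from loss_cdq:

    from common.database_dml_koi import get_h5_model_from_mongo, get_scaler_model_from_mongo
    from models_processing.losses.loss_cdq import rmse

    args = {
        'mongodb_database': 'example_db',    # placeholder database name
        'model_table': 'models_example',     # placeholder model collection
        'scaler_table': 'scalers_example',   # placeholder scaler collection
        'model_name': 'bp_example',          # placeholder model name
    }
    # Both helpers pick the newest document for model_name by sorting gen_time DESCENDING
    model = get_h5_model_from_mongo(args, custom_objects={'rmse': rmse})
    feature_scaler, target_scaler = get_scaler_model_from_mongo(args)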
+ 1 - 2
models_processing/losses/loss_cdq.py

@@ -4,13 +4,12 @@
 # file: loss.py.py
 # author: David
 # company: shenyang JY
-from keras import backend as K
 import tensorflow as tf
 tf.compat.v1.set_random_seed(1234)
 
 
 def rmse(y_true, y_pred):
-    return K.sqrt(K.mean(K.square(y_pred - y_true)))
+    return tf.sqrt(tf.reduce_mean(tf.square(y_pred - y_true)))
 
 class SouthLoss(tf.keras.losses.Loss):
     def __init__(self, opt, name='south_loss'):

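Note: rmse now uses native TensorFlow ops instead of the removed keras.backend import. A quick sanity check of the new implementation (input values chosen arbitrarily):

    import tensorflow as tf
    from models_processing.losses.loss_cdq import rmse

    y_true = tf.constant([1.0, 2.0, 3.0])
    y_pred = tf.constant([1.5, 2.0, 2.0])
    # sqrt(mean((y_pred - y_true)^2)) = sqrt((0.25 + 0.0 + 1.0) / 3) ≈ 0.6455
    print(float(rmse(y_true, y_pred)))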
+ 1 - 1
models_processing/model_koi/tf_bp.py

@@ -11,7 +11,7 @@ from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoa
 from tensorflow.keras import optimizers, regularizers
 from models_processing.losses.loss_cdq import rmse
 import numpy as np
-from common.database_dml import *
+from common.database_dml_koi import *
 from threading import Lock
 import argparse
 model_lock = Lock()

+ 2 - 1
models_processing/model_koi/tf_bp_pre.py

@@ -8,7 +8,7 @@ import json, copy
 import numpy as np
 from flask import Flask, request
 import logging, argparse, traceback
-from common.database_dml import *
+from common.database_dml_koi import *
 from common.processing_data_common import missing_features, str_to_list
 from data_processing.data_operation.data_handler import DataHandler
 from threading import Lock
@@ -53,6 +53,7 @@ def model_prediction_bp():
         pre_data = get_data_from_mongo(args)
         feature_scaler, target_scaler = get_scaler_model_from_mongo(args)
         scaled_pre_x = dh.pre_data_handler(pre_data, feature_scaler, bp_data=True)
+        # ------------ Fetch the model and run prediction ------------
         bp.get_model(args)
         res = list(chain.from_iterable(target_scaler.inverse_transform([bp.predict(scaled_pre_x).flatten()])))
         pre_data['power_forecast'] = res[:len(pre_data)]

+ 4 - 3
models_processing/model_koi/tf_bp_train.py

@@ -13,7 +13,7 @@ import logging, argparse
 from data_processing.data_operation.data_handler import DataHandler
 import time, yaml
 from models_processing.model_koi.tf_bp import BPHandler
-from common.database_dml import *
+from common.database_dml_koi import *
 import matplotlib.pyplot as plt
 from common.logs import Log
 logger = logging.getLogger()
@@ -26,7 +26,7 @@ with app.app_context():
     with open('../model_koi/bp.yaml', 'r', encoding='utf-8') as f:
         args = yaml.safe_load(f)
     dh = DataHandler(logger, args)
-    bp = BPHandler(logger)
+    bp = BPHandler(logger, args)
     global opt
 
 @app.before_request
@@ -51,9 +51,10 @@ def model_training_bp():
     # ------------ 获取数据,预处理训练数据 ------------
     train_data = get_data_from_mongo(args)
     train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes = dh.train_data_handler(train_data, bp_data=True)
-    # ------------ 训练模型,保存模型 ------------
+    # ------------ Train the model ------------
     bp.opt.Model['input_size'] = train_x.shape[1]
     bp_model = bp.training([train_x, train_y, valid_x, valid_y])
+    # ------------ Save the model ------------
     args['params'] = json.dumps(args)
     args['descr'] = '测试'
     args['gen_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
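Note: a minimal sketch of how the training endpoint is expected to hand its outputs to the new capacity-capped helpers; the two calls below are illustrative, using the args assembled above and the byte streams returned by dh.train_data_handler:

    # Both collections are capped at 50 documents; the oldest entry (by gen_time) is evicted first
    insert_trained_model_into_mongo(bp_model, args)
    insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, args)

Because gen_time is written here as a 'YYYY-MM-DD HH:MM:SS' string, sorting on it lexicographically still matches chronological order, so the eviction and latest-version lookups behave as intended.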