David 1 kuukausi sitten
vanhempi
commit
788c3bc2e3
2 muutettua tiedostoa jossa 108 lisäystä ja 5 poistoa
  1. 106 3
      common/database_dml_koi.py
  2. 2 2
      models_processing/model_tf/tf_lstm.py

+ 106 - 3
common/database_dml_koi.py

@@ -174,12 +174,12 @@ def insert_trained_model_into_mongo(model: tf.keras.Model, args: Dict[str, Any])
 
     try:
         # ------------------------- 临时文件处理 -------------------------
-        fd, temp_path = tempfile.mkstemp(suffix='.h5')
+        fd, temp_path = tempfile.mkstemp(suffix='.keras')
         os.close(fd)  # 立即释放文件锁
 
         # ------------------------- 模型保存 -------------------------
         try:
-            model.save(temp_path, save_format='h5')
+            model.save(temp_path) # 不指定save_format,默认使用keras新格式
         except Exception as e:
             raise RuntimeError(f"模型保存失败: {str(e)}") from e
 
@@ -337,6 +337,7 @@ def get_h5_model_from_mongo( args: Dict[str, Any], custom_objects: Optional[Dict
     # ------------------------- 环境配置 -------------------------
     mongo_uri = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
     client = None
+    tmp_file_path = None  # 用于跟踪临时文件路径
     try:
         # ------------------------- 数据库连接 -------------------------
         client = MongoClient(
@@ -379,7 +380,7 @@ def get_h5_model_from_mongo( args: Dict[str, Any], custom_objects: Optional[Dict
             # with h5py.File(model_buffer, 'r', driver='fileobj') as f:
             #     model = tf.keras.models.load_model(f, custom_objects=custom_objects)
             # 创建临时文件
-            with tempfile.NamedTemporaryFile(suffix=".h5", delete=False) as tmp_file:
+            with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
                 tmp_file.write(model_data)
                 tmp_file_path = tmp_file.name  # 获取临时文件路径
 
@@ -400,7 +401,109 @@ def get_h5_model_from_mongo( args: Dict[str, Any], custom_objects: Optional[Dict
         # ------------------------- 资源清理 -------------------------
         if client:
             client.close()
+        # 确保删除临时文件
+        if tmp_file_path and os.path.exists(tmp_file_path):
+            try:
+                os.remove(tmp_file_path)
+                print(f"🧹 已清理临时文件: {tmp_file_path}")
+            except Exception as cleanup_err:
+                print(f"⚠️ 临时文件清理失败: {str(cleanup_err)}")
+
+
+def get_keras_model_from_mongo(
+        args: Dict[str, Any],
+        custom_objects: Optional[Dict[str, Any]] = None
+) -> Optional[tf.keras.Model]:
+    """
+    Fetch the most recent version of the named model from MongoDB (.keras format).
+
+    Parameters:
+    args : dict - dictionary that must contain the following keys:
+        mongodb_database: database name
+        model_table: collection name
+        model_name: name of the model to fetch
+    custom_objects: dict - custom Keras objects forwarded to load_model
+
+    Returns:
+    tf.keras.Model - the loaded model, or None when no matching document is found
+    """
+    # ------------------------- Argument validation -------------------------
+    required_keys = {'mongodb_database', 'model_table', 'model_name'}
+    if missing := required_keys - args.keys():
+        raise ValueError(f"❌ 缺失必要参数: {missing}")
+
+    # --- Environment setup. NOTE(review): the fallback URI below embeds credentials ---
+    mongo_uri = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
+    client = None
+    tmp_file_path = None  # tracks the temp file path so the finally block can remove it
+
+    try:
+        # ------------------------- Database connection -------------------------
+        client = MongoClient(
+            mongo_uri,
+            maxPoolSize=10,
+            socketTimeoutMS=5000
+        )
+        db = client[args['mongodb_database']]
+        collection = db[args['model_table']]
 
+        # ------------------------- Index maintenance -------------------------
+        index_name = "model_gen_time_idx"
+        if index_name not in collection.index_information():
+            collection.create_index(
+                [("model_name", 1), ("gen_time", DESCENDING)],
+                name=index_name
+            )
+            print("⏱️ 已创建复合索引")
+
+        # ------------------------- Latest-document query -------------------------
+        model_doc = collection.find_one(
+            {"model_name": args['model_name']},
+            sort=[('gen_time', DESCENDING)],
+            projection={"model_data": 1, "gen_time": 1}
+        )
+
+        if not model_doc:
+            print(f"⚠️ 未找到模型 '{args['model_name']}' 的有效记录")
+            return None
+
+        # ------------------------- Model loading -------------------------
+        model_data = model_doc['model_data']
+
+        # Write the bytes to a temp file (delete=False: it is removed in the finally block)
+        with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
+            tmp_file.write(model_data)
+            tmp_file_path = tmp_file.name  # remember the path for loading and cleanup
+
+        # Load the model from the temp file
+        model = tf.keras.models.load_model(
+            tmp_file_path,
+            custom_objects=custom_objects
+        )
+
+        print(f"{args['model_name']} 模型成功从 MongoDB 加载!")
+        return model
+
+    except tf.errors.NotFoundError as e:
+        print(f"❌ 模型结构缺失关键组件: {str(e)}")
+        raise RuntimeError("模型架构不完整") from e
+
+    except Exception as e:
+        print(f"❌ 系统异常: {str(e)}")
+        raise
+
+    finally:
+        # ------------------------- Resource cleanup -------------------------
+        if client:
+            client.close()
+
+        # Always remove the temp file if one was created
+        if tmp_file_path and os.path.exists(tmp_file_path):
+            try:
+                os.remove(tmp_file_path)
+                print(f"🧹 已清理临时文件: {tmp_file_path}")
+            except Exception as cleanup_err:
+                print(f"⚠️ 临时文件清理失败: {str(cleanup_err)}")
 
 def get_scaler_model_from_mongo(args: Dict[str, Any], only_feature_scaler: bool = False) -> Union[object, Tuple[object, object]]:
     """

+ 2 - 2
models_processing/model_tf/tf_lstm.py

@@ -31,7 +31,7 @@ class TSHandler(object):
         try:
             with model_lock:
                 loss = region_loss(self.opt)
-                self.model = get_h5_model_from_mongo(args, {type(loss).__name__: loss})
+                self.model = get_keras_model_from_mongo(args, {type(loss).__name__: loss})
         except Exception as e:
             self.logger.info("加载模型权重失败:{}".format(e.args))
 
@@ -58,7 +58,7 @@ class TSHandler(object):
             if self.opt.Model['add_train']:
                 # 进行加强训练,支持修模
                 loss = region_loss(self.opt)
-                base_train_model = get_h5_model_from_mongo(vars(self.opt), {type(loss).__name__: loss})
+                base_train_model = get_keras_model_from_mongo(vars(self.opt), {type(loss).__name__: loss})
                 base_train_model.summary()
                 self.logger.info("已加载加强训练基础模型")
             else: