David vor 1 Monat
Ursprung
Commit
a00b102627

+ 29 - 21
common/database_dml_koi.py

@@ -189,9 +189,17 @@ def insert_trained_model_into_mongo(model: tf.keras.Model, args: Dict[str, Any])
         collection = db[args['model_table']]
 
         # ------------------------- 索引检查 -------------------------
-        if "gen_time_1" not in collection.index_information():
-            collection.create_index([("gen_time", ASCENDING)], name="gen_time_1")
-            print("已创建时间索引")
+        # index_info = collection.index_information()
+        # if "gen_time_1" not in index_info:
+        #     print("开始创建索引...")
+        #     collection.create_index(
+        #         [("gen_time", ASCENDING)],
+        #         name="gen_time_1",
+        #         background=True
+        #     )
+        #     print("索引创建成功")
+        # else:
+        #     print("索引已存在,跳过创建")
 
         # ------------------------- 容量控制 -------------------------
         # 使用更高效的计数方式
@@ -271,9 +279,9 @@ def insert_scaler_model_into_mongo(feature_scaler_bytes: BytesIO, target_scaler_
         collection = db[args['scaler_table']]
 
         # ------------------------- 索引维护 -------------------------
-        if "gen_time_1" not in collection.index_information():
-            collection.create_index([("gen_time", ASCENDING)], name="gen_time_1")
-            print("⏱️ 已创建时间排序索引")
+        # if "gen_time_1" not in collection.index_information():
+        #     collection.create_index([("gen_time", ASCENDING)], name="gen_time_1")
+        #     print("⏱️ 已创建时间排序索引")
 
         # ------------------------- 容量控制 -------------------------
         # 使用近似计数提升性能(误差在几十条内可接受)
@@ -448,13 +456,13 @@ def get_keras_model_from_mongo(
         collection = db[args['model_table']]
 
         # ------------------------- 索引维护 -------------------------
-        index_name = "model_gen_time_idx"
-        if index_name not in collection.index_information():
-            collection.create_index(
-                [("model_name", 1), ("gen_time", DESCENDING)],
-                name=index_name
-            )
-            print("⏱️ 已创建复合索引")
+        # index_name = "model_gen_time_idx"
+        # if index_name not in collection.index_information():
+        #     collection.create_index(
+        #         [("model_name", 1), ("gen_time", DESCENDING)],
+        #         name=index_name
+        #     )
+        #     print("⏱️ 已创建复合索引")
 
         # ------------------------- 高效查询 -------------------------
         model_doc = collection.find_one(
@@ -544,14 +552,14 @@ def get_scaler_model_from_mongo(args: Dict[str, Any], only_feature_scaler: bool
         collection = db[args['scaler_table']]
 
         # ------------------------- 索引维护 -------------------------
-        index_name = "model_gen_time_idx"
-        if index_name not in collection.index_information():
-            collection.create_index(
-                [("model_name", 1), ("gen_time", DESCENDING)],
-                name=index_name,
-                background=True  # 后台构建避免阻塞
-            )
-            print("⏱️ 已创建特征缩放器复合索引")
+        # index_name = "model_gen_time_idx"
+        # if index_name not in collection.index_information():
+        #     collection.create_index(
+        #         [("model_name", 1), ("gen_time", DESCENDING)],
+        #         name=index_name,
+        #         background=True  # 后台构建避免阻塞
+        #     )
+        #     print("⏱️ 已创建特征缩放器复合索引")
 
         # ------------------------- 高效查询 -------------------------
         scaler_doc = collection.find_one(

+ 2 - 2
models_processing/model_tf/tf_bp.py

@@ -32,7 +32,7 @@ class BPHandler(object):
         try:
             with model_lock:
                 # loss = region_loss(self.opt)
-                self.model = get_h5_model_from_mongo(args)
+                self.model = get_keras_model_from_mongo(args)
         except Exception as e:
             self.logger.info("加载模型权重失败:{}".format(e.args))
 
@@ -89,7 +89,7 @@ class BPHandler(object):
             if self.opt.Model['add_train']:
                 # 进行加强训练,支持修模
                 loss = region_loss(self.opt)
-                base_train_model = get_h5_model_from_mongo(vars(self.opt), {type(loss).__name__: loss})
+                base_train_model = get_keras_model_from_mongo(vars(self.opt), {type(loss).__name__: loss})
                 base_train_model.summary()
                 self.logger.info("已加载加强训练基础模型")
             else:

+ 2 - 2
models_processing/model_tf/tf_cnn.py

@@ -32,7 +32,7 @@ class CNNHandler(object):
         try:
             with model_lock:
                 loss = region_loss(self.opt)
-                self.model = get_h5_model_from_mongo(args, {type(loss).__name__: loss})
+                self.model = get_keras_model_from_mongo(args, {type(loss).__name__: loss})
         except Exception as e:
             self.logger.info("加载模型权重失败:{}".format(e.args))
 
@@ -61,7 +61,7 @@ class CNNHandler(object):
             if self.opt.Model['add_train']:
                 # 进行加强训练,支持修模
                 loss = region_loss(self.opt)
-                base_train_model = get_h5_model_from_mongo(vars(self.opt), {type(loss).__name__: loss})
+                base_train_model = get_keras_model_from_mongo(vars(self.opt), {type(loss).__name__: loss})
                 base_train_model.summary()
                 self.logger.info("已加载加强训练基础模型")
             else: