Author: David — 1 month ago
Parent commit: 57fda6a5d7

+ 3 - 3
common/database_dml_koi.py

@@ -468,7 +468,7 @@ def get_keras_model_from_mongo(
         model_doc = collection.find_one(
             {"model_name": args['model_name']},
             sort=[('gen_time', DESCENDING)],
-            projection={"model_data": 1, "gen_time": 1}
+            projection={"model_data": 1, "gen_time": 1, 'params':1}
         )
 
         if not model_doc:
@@ -477,7 +477,7 @@ def get_keras_model_from_mongo(
 
         # ------------------------- 内存优化加载 -------------------------
         model_data = model_doc['model_data']
-
+        model_params = model_doc['params']
         # 创建临时文件(自动删除)
         with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
             tmp_file.write(model_data)
@@ -490,7 +490,7 @@ def get_keras_model_from_mongo(
         )
 
         print(f"{args['model_name']} 模型成功从 MongoDB 加载!")
-        return model
+        return model, model_params
 
     except tf.errors.NotFoundError as e:
         print(f"❌ 模型结构缺失关键组件: {str(e)}")

+ 9 - 0
data_processing/data_operation/data_handler.py

@@ -102,6 +102,12 @@ class DataHandler(object):
             data_train = self.data_fill(data_train, col_time)
         return data_train
 
+    def fill_pre_data(self, unite):
+        unite = unite.interpolate(method='linear')  # nwp先进行线性填充
+        unite = unite.fillna(method='ffill')  # 再对超过采样边缘无法填充的点进行二次填充
+        unite = unite.fillna(method='bfill')
+        return unite
+
     def missing_time_splite(self, df, dt_short, dt_long, col_time):
         df.reset_index(drop=True, inplace=True)
         n_long, n_short, n_points = 0, 0, 0
@@ -183,6 +189,7 @@ class DataHandler(object):
         # 对清洗完限电的数据进行特征预处理:
         # 1.空值异常值清洗
         train_data_cleaned = cleaning(train_data, '训练集', self.logger, features + [target], col_time)
+        self.opt.features = [x for x in train_data_cleaned.columns.tolist() if x not in [target, col_time] and x in features]
         # 2. 标准化
         # 创建特征和目标的标准化器
         train_scaler = MinMaxScaler(feature_range=(0, 1))
@@ -228,6 +235,8 @@ class DataHandler(object):
         col_time, features = self.opt.col_time, self.opt.features
         data = data.applymap(lambda x: float(x.to_decimal()) if isinstance(x, Decimal128) else float(x) if isinstance(x, numbers.Number) else x)
         data = data.sort_values(by=col_time).reset_index(drop=True, inplace=False)
+        if self.opt.Model['predict_data_fill']:
+            data = self.fill_pre_data(data)
         pre_data = data[features]
         scaled_features = feature_scaler.transform(data[features])
         pre_data.loc[:, features] = scaled_features