David před 3 týdny
rodič
revize
848247e5ee

+ 48 - 5
data_processing/data_operation/data_handler.py

@@ -18,12 +18,17 @@ class DataHandler(object):
         self.logger = logger
         self.opt = argparse.Namespace(**args)
 
-    def get_train_data(self, dfs, col_time, target):
+    def get_train_data(self, dfs, col_time, target, time_series=1):
         train_x, valid_x, train_y, valid_y = [], [], [], []
         for i, df in enumerate(dfs, start=1):
             if len(df) < self.opt.Model["time_step"]:
                 self.logger.info("特征处理-训练数据-不满足time_step")
-            datax, datay = self.get_timestep_features(df, col_time, target, is_train=True)
+            if time_series == 2:
+                datax, datay = self.get_timestep_features_lstm2(df, col_time, target, is_train=True)
+            elif time_series == 3:
+                datax, datay = self.get_timestep_features_bilstm(df, col_time, target, is_train=True)
+            else:
+                datax, datay = self.get_timestep_features(df, col_time, target, is_train=True)
             if len(datax) < 10:
                 self.logger.info("特征处理-训练数据-无法进行最小分割")
                 continue
@@ -41,23 +46,24 @@ class DataHandler(object):
 
         return train_x, valid_x, train_y, valid_y
 
-    def get_predict_data(self, dfs):
+    def get_predict_data(self, dfs, time_series=1):
         test_x = []
         for i, df in enumerate(dfs, start=1):
             if len(df) < self.opt.Model["time_step"]:
                 self.logger.info("特征处理-预测数据-不满足time_step")
                 continue
-            datax = self.get_predict_features(df)
+            datax = self.get_predict_features(df, time_series)
             test_x.append(datax)
         test_x = np.concatenate(test_x, axis=0)
         return test_x
 
-    def get_predict_features(self, norm_data):
+    def get_predict_features(self, norm_data, time_series=1):
         """
         均分数据,获取预测数据集
         """
         time_step = self.opt.Model["time_step"]
         feature_data = norm_data.reset_index(drop=True)
+        time_step *= int(time_series)
         time_step_loc = time_step - 1
         iters = int(len(feature_data)) // self.opt.Model['time_step']
         end = int(len(feature_data)) % self.opt.Model['time_step']
@@ -86,6 +92,43 @@ class DataHandler(object):
             features_y.append(row[1])
         return features_x, features_y
 
+    def get_timestep_features_lstm2(self, norm_data, col_time, target, is_train):
+        """
+        步长分割数据,获取时序训练集
+        """
+        time_step = self.opt.Model["time_step"]
+        feature_data = norm_data.reset_index(drop=True)
+        time_step_loc = time_step*2 - 1
+        train_num = int(len(feature_data))
+        label_features = [col_time, target] if is_train is True else [col_time, target]
+        nwp_cs = self.opt.features
+        nwp = [feature_data.loc[i:i + time_step_loc, nwp_cs].reset_index(drop=True) for i in range(train_num - time_step*2 + 1)]  # 数据库字段 'C_T': 'C_WS170'
+        labels = [feature_data.loc[i+time_step:i+time_step_loc, label_features].reset_index(drop=True) for i in range(train_num - time_step*2 + 1)]
+        features_x, features_y = [], []
+        for i, row in enumerate(zip(nwp, labels)):
+            features_x.append(row[0])
+            features_y.append(row[1])
+        return features_x, features_y
+
+    def get_timestep_features_bilstm(self, norm_data, col_time, target, is_train):
+        """
+        步长分割数据,获取时序训练集
+        """
+        time_step = self.opt.Model["time_step"]
+        feature_data = norm_data.reset_index(drop=True)
+        time_step_loc = time_step*3 - 1
+        time_step_m = time_step*2 - 1
+        train_num = int(len(feature_data))
+        label_features = [col_time, target] if is_train is True else [col_time, target]
+        nwp_cs = self.opt.features
+        nwp = [feature_data.loc[i:i + time_step_loc, nwp_cs].reset_index(drop=True) for i in range(train_num - time_step*3 + 1)]  # 数据库字段 'C_T': 'C_WS170'
+        labels = [feature_data.loc[i+time_step:i+time_step_m, label_features].reset_index(drop=True) for i in range(train_num - time_step*3 + 1)]
+        features_x, features_y = [], []
+        for i, row in enumerate(zip(nwp, labels)):
+            features_x.append(row[0])
+            features_y.append(row[1])
+        return features_x, features_y
+
     def fill_train_data(self, unite, col_time):
         """
         补值

+ 0 - 23
models_processing/model_tf/test.py

@@ -1,23 +0,0 @@
-#!/usr/bin/env python
-# -*- coding:utf-8 -*-
-# @FileName  :test.py
-# @Time      :2025/3/25 09:16
-# @Author    :David
-# @Company: shenyang JY
-
-{"features": ["temperature190", "temperature10", "direction160", "direction40", "temperature110",
-              "speed60", "direction80", "mcc", "temperature150", "speed20", "speed110",
-              "globalr", "solarZenith", "speed190", "direction120", "direction200",
-              "temperature90", "speed150", "temperature50", "direction30",
-              "temperature160", "direction170", "temperature20",
-              "direction70", "direction130", "temperature200", "speed70", "temperature120",
-              "speed30", "speed100", "speed80", "speed180", "dniCalcd", "speed140",
-              "temperature60", "temperature170", "temperature30", "direction20",
-              "humidity2", "direction180", "direction60", "direction140", "hcc", "speed40",
-              "clearskyGhi", "temperature130", "lcc", "speed90", "tcc", "temperature2",
-              "speed170", "direction100", "temperature70", "speed130", "direction190", "temperature40",
-              "direction10", "temperature180", "direction150", "direction50", "speed50", "direction90",
-              "temperature100", "speed10", "temperature140", "speed120", "speed200", "radiation", "tpr",
-              "surfacePressure", "direction110", "speed160", "temperature80"]}
-if __name__ == "__main__":
-    run_code = 0

+ 10 - 0
models_processing/model_tf/tf_bilstm.py

@@ -0,0 +1,10 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+# @FileName  :tf_bilstm.py.py
+# @Time      :2025/4/14 15:43
+# @Author    :David
+# @Company: shenyang JY
+ 
+ 
+if __name__ == "__main__":
+    run_code = 0