David 6 days ago
parent
commit
40587656be
1 changed files with 8 additions and 8 deletions
  1. 8 8
      data_processing/data_operation/data_handler.py

+ 8 - 8
data_processing/data_operation/data_handler.py

@@ -18,17 +18,17 @@ class DataHandler(object):
         self.logger = logger
         self.opt = argparse.Namespace(**args)
 
-    def get_train_data(self, dfs, col_time, target, time_series=1):
+    def get_train_data(self, dfs, col_time, target, time_series=1, lstm_type=1):
         train_x, valid_x, train_y, valid_y = [], [], [], []
         for i, df in enumerate(dfs, start=1):
             if len(df) < self.opt.Model["time_step"]:
                 self.logger.info("特征处理-训练数据-不满足time_step")
-            if time_series == 2:
+            if lstm_type == 2:
                 datax, datay = self.get_timestep_features_lstm2(df, col_time, target, is_train=True)
-            elif time_series == 3:
+            elif lstm_type == 3:
                 datax, datay = self.get_timestep_features_bilstm(df, col_time, target, is_train=True)
             else:
-                datax, datay = self.get_timestep_features(df, col_time, target, is_train=True)
+                datax, datay = self.get_timestep_features(df, col_time, target, is_train=True, time_series=time_series)
             if len(datax) < 10:
                 self.logger.info("特征处理-训练数据-无法进行最小分割")
                 continue
@@ -74,18 +74,18 @@ class DataHandler(object):
             features_x = np.concatenate((features_x, np.expand_dims(df_repeated, 0)), axis=0)
         return features_x
 
-    def get_timestep_features(self, norm_data, col_time, target, is_train):
+    def get_timestep_features(self, norm_data, col_time, target, is_train, time_series=1):
         """
         步长分割数据,获取时序训练集
         """
         time_step = self.opt.Model["time_step"]
         feature_data = norm_data.reset_index(drop=True)
-        time_step_loc = time_step - 1
+        time_step_loc = time_step*time_series - 1
         train_num = int(len(feature_data))
         label_features = [col_time, target] if is_train is True else [col_time, target]
         nwp_cs = self.opt.features
-        nwp = [feature_data.loc[i:i + time_step_loc, nwp_cs].reset_index(drop=True) for i in range(train_num - time_step + 1)]  # 数据库字段 'C_T': 'C_WS170'
-        labels = [feature_data.loc[i:i + time_step_loc, label_features].reset_index(drop=True) for i in range(train_num - time_step + 1)]
+        nwp = [feature_data.loc[i:i + time_step_loc, nwp_cs].reset_index(drop=True) for i in range(train_num - time_step*time_series + 1)]  # 数据库字段 'C_T': 'C_WS170'
+        labels = [feature_data.loc[i:i + time_step_loc, label_features].reset_index(drop=True) for i in range(train_num - time_step*time_series + 1)]
         features_x, features_y = [], []
         for i, row in enumerate(zip(nwp, labels)):
             features_x.append(row[0])