David cách đây 1 tháng
mục cha
commit
4a6e8e7f52
1 tập tin đã thay đổi với 10 bổ sung và 10 xóa
  1. 10 10
      data_processing/data_operation/data_handler.py

+ 10 - 10
data_processing/data_operation/data_handler.py

@@ -18,12 +18,12 @@ class DataHandler(object):
         self.logger = logger
         self.opt = argparse.Namespace(**args)
 
-    def get_train_data(self, dfs, col_time, features, target):
+    def get_train_data(self, dfs, col_time, target):
         train_x, valid_x, train_y, valid_y = [], [], [], []
         for i, df in enumerate(dfs, start=1):
             if len(df) < self.opt.Model["time_step"]:
                 self.logger.info("特征处理-训练数据-不满足time_step")
-            datax, datay = self.get_timestep_features(df, col_time, features, target, is_train=True)
+            datax, datay = self.get_timestep_features(df, col_time, target, is_train=True)
             if len(datax) < 10:
                 self.logger.info("特征处理-训练数据-无法进行最小分割")
                 continue
@@ -41,18 +41,18 @@ class DataHandler(object):
 
         return train_x, valid_x, train_y, valid_y
 
-    def get_predict_data(self, dfs, features):
+    def get_predict_data(self, dfs):
         test_x = []
         for i, df in enumerate(dfs, start=1):
             if len(df) < self.opt.Model["time_step"]:
                 self.logger.info("特征处理-预测数据-不满足time_step")
                 continue
-            datax = self.get_predict_features(df, features)
+            datax = self.get_predict_features(df)
             test_x.append(datax)
         test_x = np.concatenate(test_x, axis=0)
         return test_x
 
-    def get_predict_features(self, norm_data, features):
+    def get_predict_features(self, norm_data):
         """
         均分数据,获取预测数据集
         """
@@ -61,14 +61,14 @@ class DataHandler(object):
         time_step_loc = time_step - 1
         iters = int(len(feature_data)) // self.opt.Model['time_step']
         end = int(len(feature_data)) % self.opt.Model['time_step']
-        features_x = np.array([feature_data.loc[i*time_step:i*time_step + time_step_loc, features].reset_index(drop=True) for i in range(iters)])
+        features_x = np.array([feature_data.loc[i*time_step:i*time_step + time_step_loc, self.opt.features].reset_index(drop=True) for i in range(iters)])
         if end > 0:
             df = feature_data.tail(end)
             df_repeated = pd.concat([df] + [pd.DataFrame([df.iloc[0]]* (time_step-end))]).reset_index(drop=True)
             features_x = np.concatenate((features_x, np.expand_dims(df_repeated, 0)), axis=0)
         return features_x
 
-    def get_timestep_features(self, norm_data, col_time, features, target, is_train):
+    def get_timestep_features(self, norm_data, col_time, target, is_train):
         """
         步长分割数据,获取时序训练集
         """
@@ -77,7 +77,7 @@ class DataHandler(object):
         time_step_loc = time_step - 1
         train_num = int(len(feature_data))
         label_features = [col_time, target] if is_train is True else [col_time, target]
-        nwp_cs = features
+        nwp_cs = self.opt.features
         nwp = [feature_data.loc[i:i + time_step_loc, nwp_cs].reset_index(drop=True) for i in range(train_num - time_step + 1)]  # 数据库字段 'C_T': 'C_WS170'
         labels = [feature_data.loc[i:i + time_step_loc, label_features].reset_index(drop=True) for i in range(train_num - time_step + 1)]
         features_x, features_y = [], []
@@ -214,7 +214,7 @@ class DataHandler(object):
             train_x, valid_x, train_y, valid_y = self.train_valid_split(train_data[self.opt.features].values, train_data[target].values, valid_rate=self.opt.Model["valid_data_rate"], shuffle=self.opt.Model['shuffle_train_data'])
             train_x, valid_x, train_y, valid_y =  np.array(train_x), np.array(valid_x), np.array(train_y), np.array(valid_y)
         else:
-            train_x, valid_x, train_y, valid_y = self.get_train_data(train_datas, col_time, self.opt.features, target)
+            train_x, valid_x, train_y, valid_y = self.get_train_data(train_datas, col_time, target)
         return train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap
 
     def pre_data_handler(self, data, feature_scaler, bp_data=False):
@@ -241,5 +241,5 @@ class DataHandler(object):
         if bp_data:
             pre_x = np.array(pre_data)
         else:
-            pre_x = self.get_predict_data([pre_data], features)
+            pre_x = self.get_predict_data([pre_data])
         return pre_x, data