David 2 mesi fa
parent
commit
516388da2f
1 ha cambiato i file con 4 aggiunte e 4 eliminazioni
  1. 4 4
      data_processing/data_operation/data_handler.py

+ 4 - 4
data_processing/data_operation/data_handler.py

@@ -205,7 +205,7 @@ class DataHandler(object):
                 vy.append(data[1])
         return tx, vx, ty, vy
 
-    def train_data_handler(self, data, bp_data=False):
+    def train_data_handler(self, data, bp_data=False, time_series=1):
         """
         训练数据预处理:
         清洗+补值+归一化
@@ -257,10 +257,10 @@ class DataHandler(object):
             train_x, valid_x, train_y, valid_y = self.train_valid_split(train_data[self.opt.features].values, train_data[target].values, valid_rate=self.opt.Model["valid_data_rate"], shuffle=self.opt.Model['shuffle_train_data'])
             train_x, valid_x, train_y, valid_y =  np.array(train_x), np.array(valid_x), np.array(train_y), np.array(valid_y)
         else:
-            train_x, valid_x, train_y, valid_y = self.get_train_data(train_datas, col_time, target)
+            train_x, valid_x, train_y, valid_y = self.get_train_data(train_datas, col_time, target, time_series)
         return train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap
 
-    def pre_data_handler(self, data, feature_scaler, bp_data=False):
+    def pre_data_handler(self, data, feature_scaler, bp_data=False, time_series=1):
         """
         预测数据简单处理
         Args:
@@ -286,5 +286,5 @@ class DataHandler(object):
         if bp_data:
             pre_x = np.array(pre_data)
         else:
-            pre_x = self.get_predict_data([pre_data])
+            pre_x = self.get_predict_data([pre_data], time_series)
         return pre_x, data