浏览代码

Merge branch 'dev_david' of anweiguo/algorithm_platform into dev_awg

liudawei 3 月之前
父节点
当前提交
6bdf37515f

+ 7 - 5
models_processing/model_koi/tf_bp_pre.py

@@ -59,16 +59,18 @@ def model_prediction_bp():
         res = list(chain.from_iterable(target_scaler.inverse_transform([bp.predict(scaled_pre_x).flatten()])))
         pre_data['power_forecast'] = res[:len(pre_data)]
         pre_data['farm_id'] = 'J00083'
-        pre_data['cdq'] = args.get('cdq', 1)
-        pre_data['dq'] = args.get('dq', 1)
-        pre_data['zq'] = args.get('zq', 1)
-        res_cols = ['date_time', 'power_forecast', 'farm_id', 'cdq', 'dq', 'zq']
+
         if args.get('algorithm_test', 0):
             pre_data['model'] = 'lstm'
-            res_cols += [args['target'], 'model']
             pre_data.rename(columns={args['col_time']: 'dateTime'}, inplace=True)
+            res_cols = ['dateTime', 'power_forecast', 'farm_id', args['target'], 'model']
         else:
+            pre_data['cdq'] = args.get('cdq', 1)
+            pre_data['dq'] = args.get('dq', 1)
+            pre_data['zq'] = args.get('zq', 1)
             pre_data.rename(columns={args['col_time']: 'date_time'}, inplace=True)
+            res_cols = ['date_time', 'power_forecast', 'farm_id', 'cdq', 'dq', 'zq']
+
         pre_data = pre_data[res_cols]
 
         pre_data['power_forecast'] = pre_data['power_forecast'].round(2)

+ 5 - 5
models_processing/model_koi/tf_cnn_pre.py

@@ -60,16 +60,16 @@ def model_prediction_bp():
         res = list(chain.from_iterable(target_scaler.inverse_transform([cnn.predict(scaled_pre_x).flatten()])))
         pre_data['power_forecast'] = res[:len(pre_data)]
         pre_data['farm_id'] = 'J00083'
-        pre_data['cdq'] = args.get('cdq', 1)
-        pre_data['dq'] = args.get('dq', 1)
-        pre_data['zq'] = args.get('zq', 1)
-        res_cols = ['date_time', 'power_forecast', 'farm_id', 'cdq', 'dq', 'zq']
         if args.get('algorithm_test', 0):
             pre_data['model'] = 'cnn'
-            res_cols += [args['target'], 'model']
             pre_data.rename(columns={args['col_time']: 'dateTime'}, inplace=True)
+            res_cols = ['dateTime', 'power_forecast', 'farm_id', args['target'], 'model']
         else:
+            pre_data['cdq'] = args.get('cdq', 1)
+            pre_data['dq'] = args.get('dq', 1)
+            pre_data['zq'] = args.get('zq', 1)
             pre_data.rename(columns={args['col_time']: 'date_time'}, inplace=True)
+            res_cols = ['date_time', 'power_forecast', 'farm_id', 'cdq', 'dq', 'zq']
         pre_data = pre_data[res_cols]
 
         pre_data['power_forecast'] = pre_data['power_forecast'].round(2)

+ 5 - 5
models_processing/model_koi/tf_lstm_pre.py

@@ -59,16 +59,16 @@ def model_prediction_bp():
         res = list(chain.from_iterable(target_scaler.inverse_transform([ts.predict(scaled_pre_x).flatten()])))
         pre_data['power_forecast'] = res[:len(pre_data)]
         pre_data['farm_id'] = 'J00083'
-        pre_data['cdq'] = args.get('cdq', 1)
-        pre_data['dq'] = args.get('dq', 1)
-        pre_data['zq'] = args.get('zq', 1)
-        res_cols = ['date_time', 'power_forecast', 'farm_id', 'cdq', 'dq', 'zq']
         if args.get('algorithm_test', 0):
             pre_data['model'] = 'lstm'
-            res_cols += [args['target'], 'model']
             pre_data.rename(columns={args['col_time']: 'dateTime'}, inplace=True)
+            res_cols = ['dateTime', 'power_forecast', 'farm_id', 'cdq', 'dq', 'zq', args['target'], 'model']
         else:
+            pre_data['cdq'] = args.get('cdq', 1)
+            pre_data['dq'] = args.get('dq', 1)
+            pre_data['zq'] = args.get('zq', 1)
             pre_data.rename(columns={args['col_time']: 'date_time'}, inplace=True)
+            res_cols = ['date_time', 'power_forecast', 'farm_id', 'cdq', 'dq', 'zq']
         pre_data = pre_data[res_cols]
 
         pre_data['power_forecast'] = pre_data['power_forecast'].round(2)