David 1 月之前
父節點
當前提交
8ad4ad7139
共有 2 個文件被更改,包括 6 次插入6 次删除
  1. 3 3
      app/model/main.py
  2. 3 3
      app/predict/main.py

+ 3 - 3
app/model/main.py

@@ -69,11 +69,11 @@ def input_file_handler(input_file: str):
         if 'model' in input_file.lower():
             train_data = pd.merge(nwp_v_h, power, on='Datetime')
             if args['model_name'] == 'fmi':
-                from tf_fmi_train import model_training
+                from app.model.tf_fmi_train import model_training
             elif args['model_name'] == 'cnn':
-                from tf_cnn_train import model_training
+                from app.model.tf_cnn_train import model_training
             else:
-                from tf_lstm_train import model_training
+                from app.model.tf_lstm_train import model_training
             model_training(train_data, input_file, cap)
         # 含有predict,预测
         else:

+ 3 - 3
app/predict/main.py

@@ -69,11 +69,11 @@ def input_file_handler(input_file: str):
         if 'predict' in input_file.lower():
             pre_data = nwp_v
             if args['model_name'] == 'fmi':
-                from tf_fmi_pre import model_prediction
+                from app.predict.tf_fmi_pre import model_prediction
             elif args['model_name'] == 'cnn':
-                from tf_cnn_pre import model_prediction
+                from app.predict.tf_cnn_pre import model_prediction
             else:
-                from tf_lstm_pre import model_prediction
+                from app.predict.tf_lstm_pre import model_prediction
             model_prediction(pre_data, input_file, cap)
         else:
             logger.info("预测路径错误!")