David 1 month ago
parent
commit
8ad4ad7139
2 changed files with 6 additions and 6 deletions
  1. 3 3
      app/model/main.py
  2. 3 3
      app/predict/main.py

+ 3 - 3
app/model/main.py

@@ -69,11 +69,11 @@ def input_file_handler(input_file: str):
         if 'model' in input_file.lower():
             train_data = pd.merge(nwp_v_h, power, on='Datetime')
             if args['model_name'] == 'fmi':
-                from tf_fmi_train import model_training
+                from app.model.tf_fmi_train import model_training
             elif args['model_name'] == 'cnn':
-                from tf_cnn_train import model_training
+                from app.model.tf_cnn_train import model_training
             else:
-                from tf_lstm_train import model_training
+                from app.model.tf_lstm_train import model_training
             model_training(train_data, input_file, cap)
         # 含有predict,预测
         else:

+ 3 - 3
app/predict/main.py

@@ -69,11 +69,11 @@ def input_file_handler(input_file: str):
         if 'predict' in input_file.lower():
             pre_data = nwp_v
             if args['model_name'] == 'fmi':
-                from tf_fmi_pre import model_prediction
+                from app.predict.tf_fmi_pre import model_prediction
             elif args['model_name'] == 'cnn':
-                from tf_cnn_pre import model_prediction
+                from app.predict.tf_cnn_pre import model_prediction
             else:
-                from tf_lstm_pre import model_prediction
+                from app.predict.tf_lstm_pre import model_prediction
             model_prediction(pre_data, input_file, cap)
         else:
             logger.info("预测路径错误!")