David 1 kuukausi sitten
vanhempi
commit
4c7949a9cb
2 muutettua tiedostoa jossa 7 lisäystä ja 2 poistoa
  1. 1 1
      app/model/main.py
  2. 6 1
      app/predict/main.py

+ 1 - 1
app/model/main.py

@@ -71,7 +71,7 @@ def input_file_handler(input_file: str):
             if args['model_name'] == 'fmi':
                 from tf_fmi_train import model_training
             elif args['model_name'] == 'cnn':
-                from tf_lstm_train import model_training
+                from tf_cnn_train import model_training
             else:
                 from tf_lstm_train import model_training
             model_training(train_data, input_file, cap)

+ 6 - 1
app/predict/main.py

@@ -11,7 +11,6 @@
 import argparse
 import pandas as pd
 from pathlib import Path
-from app.predict.tf_lstm_pre import model_prediction
 from app.common.logs import logger, args
 
 """
@@ -69,6 +68,12 @@ def input_file_handler(input_file: str):
         # 含有predict,预测
         if 'predict' in input_file.lower():
             pre_data = nwp_v
+            if args['model_name'] == 'fmi':
+                from tf_fmi_pre import model_prediction
+            elif args['model_name'] == 'cnn':
+                from tf_cnn_pre import model_prediction
+            else:
+                from tf_lstm_pre import model_prediction
             model_prediction(pre_data, input_file, cap)
         else:
             logger.info("预测路径错误!")