|
@@ -11,7 +11,6 @@
|
|
|
import argparse
|
|
|
import pandas as pd
|
|
|
from pathlib import Path
|
|
|
-from app.predict.tf_lstm_pre import model_prediction
|
|
|
from app.common.logs import logger, args
|
|
|
|
|
|
"""
|
|
@@ -69,6 +68,12 @@ def input_file_handler(input_file: str):
|
|
|
# 含有predict,预测
|
|
|
if 'predict' in input_file.lower():
|
|
|
pre_data = nwp_v
|
|
|
+ if args['model_name'] == 'fmi':
|
|
|
+ from tf_fmi_pre import model_prediction
|
|
|
+ elif args['model_name'] == 'cnn':
|
|
|
+ from tf_cnn_pre import model_prediction
|
|
|
+ else:
|
|
|
+ from tf_lstm_pre import model_prediction
|
|
|
model_prediction(pre_data, input_file, cap)
|
|
|
else:
|
|
|
logger.info("预测路径错误!")
|