|
@@ -69,11 +69,11 @@ def input_file_handler(input_file: str):
|
|
|
if 'model' in input_file.lower():
|
|
|
train_data = pd.merge(nwp_v_h, power, on='Datetime')
|
|
|
if args['model_name'] == 'fmi':
|
|
|
- from tf_fmi_train import model_training
|
|
|
+ from app.model.tf_fmi_train import model_training
|
|
|
elif args['model_name'] == 'cnn':
|
|
|
- from tf_cnn_train import model_training
|
|
|
+ from app.model.tf_cnn_train import model_training
|
|
|
else:
|
|
|
- from tf_lstm_train import model_training
|
|
|
+ from app.model.tf_lstm_train import model_training
|
|
|
model_training(train_data, input_file, cap)
|
|
|
# 含有predict,预测
|
|
|
else:
|