Browse Source

Merge branch 'dev_david' of anweiguo/algorithm_platform into dev_awg

liudawei 1 month ago
parent
commit
4942195932

+ 3 - 2
models_processing/model_tf/tf_bp_pre.py

@@ -34,7 +34,8 @@ with app.app_context():
 def update_config():
     # ------------ 整理参数,整合请求参数 ------------
     args_dict = request.values.to_dict()
-    args_dict['features'] = args_dict['features'].split(',')
+    if 'features' in args_dict:
+        args_dict['features'] = args_dict['features'].split(',')
     args.update(args_dict)
     opt = argparse.Namespace(**args)
     dh.opt = opt
@@ -59,7 +60,7 @@ def model_prediction_bp():
         bp.opt.cap = round(target_scaler.transform(np.array([[float(args['cap'])]]))[0, 0], 2)
         # ------------ 获取模型,预测结果------------
         bp.get_model(args)
-        dh.opt.features = json.loads(bp.model_params).get('features', args['features'])
+        dh.opt.features = json.loads(bp.model_params).get('features', bp.opt.features)
         scaled_pre_x, pre_data = dh.pre_data_handler(pre_data, feature_scaler, bp_data=True)
 
         res = list(chain.from_iterable(target_scaler.inverse_transform(bp.predict(scaled_pre_x))))

+ 3 - 2
models_processing/model_tf/tf_cnn_pre.py

@@ -35,7 +35,8 @@ with app.app_context():
 def update_config():
     # ------------ 整理参数,整合请求参数 ------------
     args_dict = request.values.to_dict()
-    args_dict['features'] = args_dict['features'].split(',')
+    if 'features' in args_dict:
+        args_dict['features'] = args_dict['features'].split(',')
     args.update(args_dict)
     opt = argparse.Namespace(**args)
     dh.opt = opt
@@ -59,7 +60,7 @@ def model_prediction_bp():
         cnn.opt.cap = round(target_scaler.transform(np.array([[float(args['cap'])]]))[0, 0], 2)
 
         cnn.get_model(args)
-        dh.opt.features = json.loads(cnn.model_params).get('features', args['features'])
+        dh.opt.features = json.loads(cnn.model_params).get('features', cnn.opt.features)
         scaled_pre_x, pre_data = dh.pre_data_handler(pre_data, feature_scaler)
         logger.info("---------cap归一化:{}".format(cnn.opt.cap))
 

+ 3 - 2
models_processing/model_tf/tf_lstm_pre.py

@@ -35,7 +35,8 @@ with app.app_context():
 def update_config():
     # ------------ 整理参数,整合请求参数 ------------
     args_dict = request.values.to_dict()
-    args_dict['features'] = args_dict['features'].split(',')
+    if 'features' in args_dict:
+        args_dict['features'] = args_dict['features'].split(',')
     args.update(args_dict)
     opt = argparse.Namespace(**args)
     dh.opt = opt
@@ -58,7 +59,7 @@ def model_prediction_bp():
         feature_scaler, target_scaler = get_scaler_model_from_mongo(args)
         ts.opt.cap = round(target_scaler.transform(np.array([[float(args['cap'])]]))[0, 0], 2)
         ts.get_model(args)
-        dh.opt.features = json.loads(ts.model_params).get('features', args['features'])
+        dh.opt.features = json.loads(ts.model_params).get('features', ts.opt.features)
         scaled_pre_x, pre_data = dh.pre_data_handler(pre_data, feature_scaler)
         res = list(chain.from_iterable(target_scaler.inverse_transform(ts.predict(scaled_pre_x))))
         pre_data['farm_id'] = args.get('farm_id', 'null')

+ 3 - 2
models_processing/model_tf/tf_test_pre.py

@@ -35,7 +35,8 @@ with app.app_context():
 def update_config():
     # ------------ 整理参数,整合请求参数 ------------
     args_dict = request.values.to_dict()
-    args_dict['features'] = args_dict['features'].split(',')
+    if 'features' in args_dict:
+        args_dict['features'] = args_dict['features'].split(',')
     args.update(args_dict)
     opt = argparse.Namespace(**args)
     dh.opt = opt
@@ -59,7 +60,7 @@ def model_prediction_test():
         ts.opt.cap = round(target_scaler.transform(np.array([[float(args['cap'])]]))[0, 0], 2)
 
         ts.get_model(args)
-        dh.opt.features = json.loads(ts.model_params).get('features', args['features'])
+        dh.opt.features = json.loads(ts.model_params).get('features', ts.opt.features)
         scaled_pre_x, pre_data = dh.pre_data_handler(pre_data, feature_scaler)
 
         res = list(chain.from_iterable(target_scaler.inverse_transform(ts.predict(scaled_pre_x))))