David hai 4 meses
pai
achega
ba739e7d48
Modificáronse 1 ficheiros con 7 adicións e 6 borrados
  1. 7 6
      models_processing/model_koi/tf_bp_train.py

+ 7 - 6
models_processing/model_koi/tf_bp_train.py

@@ -24,9 +24,8 @@ app = Flask('tf_bp_train——service')
 
 with app.app_context():
     with open('../model_koi/bp.yaml', 'r', encoding='utf-8') as f:
-        arguments = yaml.safe_load(f)
-
-    dh = DataHandler(logger, arguments)
+        args = yaml.safe_load(f)
+    dh = DataHandler(logger, args)
     bp = BPHandler(logger)
 
 @app.route('/nn_bp_training', methods=['POST'])
@@ -36,15 +35,17 @@ def model_training_bp():
     result = {}
     success = 0
     print("Program starts execution!")
+    # ------------ 整理参数,整合请求参数 ------------
     args_dict = request.values.to_dict()
     args_dict['features'] = args_dict['features'].split(',')
-    args = copy.deepcopy(arguments)
     args.update(args_dict)
-    # try:
     opt = argparse.Namespace(**args)
     logger.info(args_dict)
-    train_data = get_data_from_mongo(args_dict)
+    # try:
+    # ------------ 获取数据,预处理训练数据 ------------
+    train_data = get_data_from_mongo(args)
     train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes = dh.train_data_handler(train_data, opt, bp_data=True)
+    # ------------ 训练模型,保存模型 ------------
     opt.Model['input_size'] = train_x.shape[1]
     bp_model = bp.training(opt, [train_x, train_y, valid_x, valid_y])
     args_dict['params'] = json.dumps(args)