|
@@ -24,9 +24,8 @@ app = Flask('tf_bp_train——service')
|
|
|
|
|
|
with app.app_context():
|
|
|
with open('../model_koi/bp.yaml', 'r', encoding='utf-8') as f:
|
|
|
- arguments = yaml.safe_load(f)
|
|
|
-
|
|
|
- dh = DataHandler(logger, arguments)
|
|
|
+ args = yaml.safe_load(f)
|
|
|
+ dh = DataHandler(logger, args)
|
|
|
bp = BPHandler(logger)
|
|
|
|
|
|
@app.route('/nn_bp_training', methods=['POST'])
|
|
@@ -36,15 +35,17 @@ def model_training_bp():
|
|
|
result = {}
|
|
|
success = 0
|
|
|
print("Program starts execution!")
|
|
|
+ # ------------ 整理参数,整合请求参数 ------------
|
|
|
args_dict = request.values.to_dict()
|
|
|
args_dict['features'] = args_dict['features'].split(',')
|
|
|
- args = copy.deepcopy(arguments)
|
|
|
args.update(args_dict)
|
|
|
- # try:
|
|
|
opt = argparse.Namespace(**args)
|
|
|
logger.info(args_dict)
|
|
|
- train_data = get_data_from_mongo(args_dict)
|
|
|
+ # try:
|
|
|
+ # ------------ 获取数据,预处理训练数据 ------------
|
|
|
+ train_data = get_data_from_mongo(args)
|
|
|
train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes = dh.train_data_handler(train_data, opt, bp_data=True)
|
|
|
+ # ------------ 训练模型,保存模型 ------------
|
|
|
opt.Model['input_size'] = train_x.shape[1]
|
|
|
bp_model = bp.training(opt, [train_x, train_y, valid_x, valid_y])
|
|
|
args_dict['params'] = json.dumps(args)
|