David 2 months ago
parent
commit
0cfb1f9901

+ 11 - 9
models_processing/model_koi/tf_bp.py

@@ -13,11 +13,13 @@ from models_processing.losses.loss_cdq import rmse
 import numpy as np
 from common.database_dml import *
 from threading import Lock
+import argparse
 model_lock = Lock()
 
 class BPHandler(object):
-    def __init__(self, logger):
+    def __init__(self, logger, args):
         self.logger = logger
+        self.opt = argparse.Namespace(**args)
         self.model = None
 
     def get_model(self, args):
@@ -42,28 +44,28 @@ class BPHandler(object):
         model.compile(loss=rmse, optimizer=adam)
         return model
 
-    def train_init(self, opt):
+    def train_init(self):
         try:
-            if opt.Model['add_train']:
+            if self.opt.Model['add_train']:
                # Enhanced (incremental) training: fine-tune the base model; supports model revision
-                base_train_model = get_h5_model_from_mongo(vars(opt), {'rmse': rmse})
+                base_train_model = get_h5_model_from_mongo(vars(self.opt), {'rmse': rmse})
                 base_train_model.summary()
                 self.logger.info("已加载加强训练基础模型")
             else:
-                base_train_model = self.get_keras_model(opt)
+                base_train_model = self.get_keras_model(self.opt)
             return base_train_model
         except Exception as e:
             self.logger.info("加强训练加载模型权重失败:{}".format(e.args))
 
-    def training(self, opt, train_and_valid_data):
-        model = self.train_init(opt)
+    def training(self, train_and_valid_data):
+        model = self.train_init()
        # tf.reset_default_graph() # clear the default graph
         train_x, train_y, valid_x, valid_y = train_and_valid_data
         print("----------", np.array(train_x[0]).shape)
         print("++++++++++", np.array(train_x[1]).shape)
         model.summary()
-        early_stop = EarlyStopping(monitor='val_loss', patience=opt.Model['patience'], mode='auto')
-        history = model.fit(train_x, train_y, batch_size=opt.Model['batch_size'], epochs=opt.Model['epoch'], verbose=2,  validation_data=(valid_x, valid_y), callbacks=[early_stop], shuffle=False)
+        early_stop = EarlyStopping(monitor='val_loss', patience=self.opt.Model['patience'], mode='auto')
+        history = model.fit(train_x, train_y, batch_size=self.opt.Model['batch_size'], epochs=self.opt.Model['epoch'], verbose=2,  validation_data=(valid_x, valid_y), callbacks=[early_stop], shuffle=False)
         loss = np.round(history.history['loss'], decimals=5)
         val_loss = np.round(history.history['val_loss'], decimals=5)
         self.logger.info("-----模型训练经过{}轮迭代-----".format(len(loss)))

+ 13 - 8
models_processing/model_koi/tf_bp_pre.py

@@ -27,8 +27,19 @@ with app.app_context():
     with open('../model_koi/bp.yaml', 'r', encoding='utf-8') as f:
         args = yaml.safe_load(f)
     dh = DataHandler(logger, args)
-    bp = BPHandler(logger)
+    bp = BPHandler(logger, args)
+    global opt
 
+@app.before_request
+def update_config():
+    # ------------ Organize parameters, merge request parameters ------------
+    args_dict = request.values.to_dict()
+    args_dict['features'] = args_dict['features'].split(',')
+    args.update(args_dict)
+    opt = argparse.Namespace(**args)
+    dh.opt = opt
+    bp.opt = opt
+    logger.info(args)
 
 @app.route('/nn_bp_predict', methods=['POST'])
 def model_prediction_bp():
@@ -37,17 +48,11 @@ def model_prediction_bp():
     result = {}
     success = 0
     print("Program starts execution!")
-    # ------------ Organize parameters, merge request parameters ------------
-    args_dict = request.values.to_dict()
-    args_dict['features'] = args_dict['features'].split(',')
-    args.update(args_dict)
-    opt = argparse.Namespace(**args)
-    logger.info(args)
     try:
        # ------------ Fetch data, preprocess prediction data ------------
         pre_data = get_data_from_mongo(args)
         feature_scaler, target_scaler = get_scaler_model_from_mongo(args)
-        scaled_pre_x = dh.pre_data_handler(pre_data, feature_scaler, opt, bp_data=True)
+        scaled_pre_x = dh.pre_data_handler(pre_data, feature_scaler, bp_data=True)
         bp.get_model(args)
         res = list(chain.from_iterable(target_scaler.inverse_transform([bp.predict(scaled_pre_x).flatten()])))
         pre_data['power_forecast'] = res[:len(pre_data)]
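The per-request parameter merge now lives in a before_request hook instead of being repeated inside each route. A minimal self-contained sketch of that pattern, with a guard for a missing 'features' field added here for illustration (not part of the commit) and SimpleNamespace stand-ins for the real DataHandler/BPHandler instances:

import argparse
from types import SimpleNamespace
from flask import Flask, request

app = Flask(__name__)
args = {'features': ['temperature'], 'Model': {'batch_size': 32}}   # stand-in for the bp.yaml defaults
dh = SimpleNamespace(opt=None)    # stand-in for the real DataHandler
bp = SimpleNamespace(opt=None)    # stand-in for the real BPHandler

@app.before_request
def update_config():
    # merge request parameters over the YAML defaults, then hand the merged
    # config to both handlers through their .opt attribute
    args_dict = request.values.to_dict()
    if 'features' in args_dict:                      # sent as a comma-separated string
        args_dict['features'] = args_dict['features'].split(',')
    args.update(args_dict)
    opt = argparse.Namespace(**args)
    dh.opt = opt
    bp.opt = opt

Note that args remains a module-level dict mutated on every request, so concurrent requests with different parameters can still interleave; the commit keeps that behaviour from the old per-route code unchanged.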

+ 16 - 10
models_processing/model_koi/tf_bp_train.py

@@ -27,6 +27,18 @@ with app.app_context():
         args = yaml.safe_load(f)
     dh = DataHandler(logger, args)
-    bp = BPHandler(logger)
+    bp = BPHandler(logger, args)
+    global opt
+
+@app.before_request
+def update_config():
+    # ------------ Organize parameters, merge request parameters ------------
+    args_dict = request.values.to_dict()
+    args_dict['features'] = args_dict['features'].split(',')
+    args.update(args_dict)
+    opt = argparse.Namespace(**args)
+    dh.opt = opt
+    bp.opt = opt
+    logger.info(args)
 
 @app.route('/nn_bp_training', methods=['POST'])
 def model_training_bp():
@@ -35,19 +47,13 @@ def model_training_bp():
     result = {}
     success = 0
     print("Program starts execution!")
-    # ------------ Organize parameters, merge request parameters ------------
-    args_dict = request.values.to_dict()
-    args_dict['features'] = args_dict['features'].split(',')
-    args.update(args_dict)
-    opt = argparse.Namespace(**args)
-    logger.info(args_dict)
     # try:
    # ------------ Fetch data, preprocess training data ------------
     train_data = get_data_from_mongo(args)
-    train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes = dh.train_data_handler(train_data, opt, bp_data=True)
+    train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes = dh.train_data_handler(train_data, bp_data=True)
    # ------------ Train the model, save the model ------------
-    opt.Model['input_size'] = train_x.shape[1]
-    bp_model = bp.training(opt, [train_x, train_y, valid_x, valid_y])
+    bp.opt.Model['input_size'] = train_x.shape[1]
+    bp_model = bp.training([train_x, train_y, valid_x, valid_y])
     args['params'] = json.dumps(args)
    args['descr'] = 'test'
     args['gen_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
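For completeness, a hedged client-side sketch of invoking the training route after this change. Only the 'features' field (a comma-separated string that update_config splits on ',') is established by the diff; the host, port and feature names below are placeholders, and any extra Mongo-related fields depend on the bp.yaml defaults:

import requests

resp = requests.post(
    'http://localhost:5000/nn_bp_training',                   # placeholder host/port
    data={'features': 'temperature,humidity,wind_speed'},     # split on ',' by update_config()
)
print(resp.status_code, resp.text)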