David 3 ماه پیش
والد
کامیت
649543a223
3فایلهای تغییر یافته به همراه27 افزوده شده و 15 حذف شده
  1. 8 11
      models_processing/model_tf/tf_test.py
  2. 4 1
      models_processing/model_tf/tf_test_pre.py
  3. 15 3
      models_processing/model_tf/tf_test_train.py

+ 8 - 11
models_processing/model_tf/tf_test.py

@@ -24,6 +24,7 @@ class TSHandler(object):
         self.logger = logger
         self.opt = argparse.Namespace(**args)
         self.model = None
+        self.model_params = None
 
     def get_model(self, args):
         """
@@ -32,7 +33,7 @@ class TSHandler(object):
         try:
             with model_lock:
                 loss = region_loss(self.opt)
-                self.model = get_keras_model_from_mongo(args, {type(loss).__name__: loss})
+                self.model, self.model_params = get_keras_model_from_mongo(args, {type(loss).__name__: loss})
         except Exception as e:
             self.logger.info("加载模型权重失败:{}".format(e.args))
 
@@ -120,20 +121,16 @@ class TSHandler(object):
 
     def train_init(self):
         try:
-            if self.opt.Model['add_train']:
-                # 进行加强训练,支持修模
-                loss = region_loss(self.opt)
-                base_train_model = get_keras_model_from_mongo(vars(self.opt), {type(loss).__name__: loss})
-                base_train_model.summary()
-                self.logger.info("已加载加强训练基础模型")
-            else:
-                base_train_model = self.get_keras_model(self.opt)
+            # 进行加强训练,支持修模
+            loss = region_loss(self.opt)
+            base_train_model, self.model_params = get_keras_model_from_mongo(vars(self.opt), {type(loss).__name__: loss})
+            base_train_model.summary()
+            self.logger.info("已加载加强训练基础模型")
             return base_train_model
         except Exception as e:
             self.logger.info("加载模型权重失败:{}".format(e.args))
 
-    def training(self, train_and_valid_data):
-        model = self.train_init()
+    def training(self, model, train_and_valid_data):
         model.summary()
         train_x, train_y, valid_x, valid_y = train_and_valid_data
         # 回调函数配置

+ 4 - 1
models_processing/model_tf/tf_test_pre.py

@@ -53,9 +53,12 @@ def model_prediction_test():
     try:
         pre_data = get_data_from_mongo(args)
         feature_scaler, target_scaler = get_scaler_model_from_mongo(args)
-        scaled_pre_x, pre_data = dh.pre_data_handler(pre_data, feature_scaler)
         ts.opt.cap = round(target_scaler.transform(np.array([[args['cap']]]))[0, 0], 2)
+
         ts.get_model(args)
+        dh.opt.features = json.loads(ts.model_params).get('features', args['features'])
+        scaled_pre_x, pre_data = dh.pre_data_handler(pre_data, feature_scaler)
+
         res = list(chain.from_iterable(target_scaler.inverse_transform(ts.predict(scaled_pre_x))))
         pre_data['farm_id'] = args.get('farm_id', 'null')
         if args.get('algorithm_test', 0):

+ 15 - 3
models_processing/model_tf/tf_test_train.py

@@ -48,11 +48,23 @@ def model_training_test():
         # ------------ 获取数据,预处理训练数据 ------------
         train_data = get_data_from_mongo(args)
         train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap = dh.train_data_handler(train_data)
-        # ------------ 训练模型,保存模型 ------------
-        ts.opt.Model['input_size'] = train_x.shape[2]
         ts.opt.cap = round(scaled_cap, 2)
-        ts_model = ts.training([train_x, train_y, valid_x, valid_y])
+        ts.opt.Model['input_size'] = len(dh.opt.features)
+        # ------------ 训练模型,保存模型 ------------
+        # 1. 如果是加强训练模式,先加载预训练模型特征参数,再预处理训练数据
+        # 2. 如果是普通模式,先预处理训练数据,再根据训练数据特征加载模型
+        model = ts.train_init() if ts.opt.Model['add_train'] else ts.get_keras_model(ts.opt)
+        if ts.opt.Model['add_train'] and model is not False:
+            feas = json.loads(ts.model_params).get('features', args['features'])
+            if set(feas).issubset(set(dh.opt.features)):
+                dh.opt.features = list(feas)
+                train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap = dh.train_data_handler(
+                    train_data)
+            else:
+                model = ts.get_keras_model(ts.opt)
+                logger.info("训练数据特征,不满足,加强训练模型特征")
 
+        ts_model = ts.training(model, [train_x, train_y, valid_x, valid_y])
         args['params'] = json.dumps(args)
         args['descr'] = '测试'
         args['gen_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))