David 3 月之前
父節點
當前提交
b87bb8c301

+ 9 - 7
models_processing/model_tf/tf_bp_train.py

@@ -55,15 +55,17 @@ def model_training_bp():
         # 1. 如果是加强训练模式,先加载预训练模型特征参数,再预处理训练数据
         # 2. 如果是普通模式,先预处理训练数据,再根据训练数据特征加载模型
         model = bp.train_init() if bp.opt.Model['add_train'] else bp.get_keras_model(bp.opt)
-        if bp.opt.Model['add_train'] and model is not False:
-            feas = json.loads(bp.model_params).get('features', args['features'])
-            if set(feas).issubset(set(dh.opt.features)):
-                dh.opt.features = list(feas)
-                train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap = dh.train_data_handler(train_data)
+        if bp.opt.Model['add_train']:
+            if model:
+                feas = json.loads(bp.model_params).get('features', args['features'])
+                if set(feas).issubset(set(dh.opt.features)):
+                    dh.opt.features = list(feas)
+                    train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap = dh.train_data_handler(train_data)
+                else:
+                    model = bp.get_keras_model(bp.opt)
+                    logger.info("训练数据特征,不满足,加强训练模型特征")
             else:
                 model = bp.get_keras_model(bp.opt)
-                logger.info("训练数据特征,不满足,加强训练模型特征")
-
         bp_model = bp.training(model, [train_x, train_y, valid_x, valid_y])
         # ------------ 保存模型 ------------
         args['params'] = json.dumps(args)

+ 9 - 7
models_processing/model_tf/tf_cnn_train.py

@@ -57,15 +57,17 @@ def model_training_bp():
         # 2. 如果是普通模式,先预处理训练数据,再根据训练数据特征加载模型
         logger.info("---------cap归一化:{}".format(cnn.opt.cap))
         model = cnn.train_init() if cnn.opt.Model['add_train'] else cnn.get_keras_model(cnn.opt)
-        if cnn.opt.Model['add_train'] and model is not False:
-            feas = json.loads(cnn.model_params).get('features', args['features'])
-            if set(feas).issubset(set(dh.opt.features)):
-                dh.opt.features = list(feas)
-                train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap = dh.train_data_handler(train_data)
+        if cnn.opt.Model['add_train']:
+            if model:
+                feas = json.loads(cnn.model_params).get('features', args['features'])
+                if set(feas).issubset(set(dh.opt.features)):
+                    dh.opt.features = list(feas)
+                    train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap = dh.train_data_handler(train_data)
+                else:
+                    model = cnn.get_keras_model(cnn.opt)
+                    logger.info("训练数据特征,不满足,加强训练模型特征")
             else:
                 model = cnn.get_keras_model(cnn.opt)
-                logger.info("训练数据特征,不满足,加强训练模型特征")
-
         bp_model = cnn.training(model, [train_x, train_y, valid_x, valid_y])
 
         args['params'] = json.dumps(args)

+ 9 - 6
models_processing/model_tf/tf_lstm_train.py

@@ -54,14 +54,17 @@ def model_training_bp():
         # 1. 如果是加强训练模式,先加载预训练模型特征参数,再预处理训练数据
         # 2. 如果是普通模式,先预处理训练数据,再根据训练数据特征加载模型
         model = ts.train_init() if ts.opt.Model['add_train'] else ts.get_keras_model(ts.opt)
-        if ts.opt.Model['add_train'] and model is not False:
-            feas = json.loads(ts.model_params).get('features', args['features'])
-            if set(feas).issubset(set(dh.opt.features)):
-                dh.opt.features = list(feas)
-                train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap = dh.train_data_handler(train_data)
+        if ts.opt.Model['add_train']:
+            if model:
+                feas = json.loads(ts.model_params).get('features', args['features'])
+                if set(feas).issubset(set(dh.opt.features)):
+                    dh.opt.features = list(feas)
+                    train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap = dh.train_data_handler(train_data)
+                else:
+                    model = ts.get_keras_model(ts.opt)
+                    logger.info("训练数据特征,不满足,加强训练模型特征")
             else:
                 model = ts.get_keras_model(ts.opt)
-                logger.info("训练数据特征,不满足,加强训练模型特征")
         ts_model = ts.training(model, [train_x, train_y, valid_x, valid_y])
         args['features'] = dh.opt.features
         args['params'] = json.dumps(args)

+ 10 - 8
models_processing/model_tf/tf_test_train.py

@@ -54,16 +54,18 @@ def model_training_test():
         # 1. 如果是加强训练模式,先加载预训练模型特征参数,再预处理训练数据
         # 2. 如果是普通模式,先预处理训练数据,再根据训练数据特征加载模型
         model = ts.train_init() if ts.opt.Model['add_train'] else ts.get_keras_model(ts.opt)
-        if ts.opt.Model['add_train'] and model is not False:
-            feas = json.loads(ts.model_params).get('features', args['features'])
-            if set(feas).issubset(set(dh.opt.features)):
-                dh.opt.features = list(feas)
-                train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap = dh.train_data_handler(
-                    train_data)
+        if ts.opt.Model['add_train']:
+            if model:
+                feas = json.loads(ts.model_params).get('features', args['features'])
+                if set(feas).issubset(set(dh.opt.features)):
+                    dh.opt.features = list(feas)
+                    train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap = dh.train_data_handler(
+                        train_data)
+                else:
+                    model = ts.get_keras_model(ts.opt)
+                    logger.info("训练数据特征,不满足,加强训练模型特征")
             else:
                 model = ts.get_keras_model(ts.opt)
-                logger.info("训练数据特征,不满足,加强训练模型特征")
-
         ts_model = ts.training(model, [train_x, train_y, valid_x, valid_y])
         args['params'] = json.dumps(args)
         args['descr'] = '测试'