浏览代码

深度替换参数

David 1 周之前
父节点
当前提交
7c00854a68

+ 29 - 0
common/data_utils.py

@@ -0,0 +1,29 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+# @FileName  :data_utils.py
+# @Time      :2025/5/21 16:16
+# @Author    :David
+# @Company: shenyang JY
+
+
+
+def deep_update(target, source):
+    """
+    递归将 source 字典的内容合并到 target 字典中
+    规则:
+      1. 若 key 在 target 和 source 中都存在,且值均为字典 → 递归合并
+      2. 若 key 在 source 中存在但 target 不存在 → 直接添加
+      3. 若 key 在 source 中存在且类型不为字典 → 覆盖 target 的值
+    """
+    for key, value in source.items():
+        # 如果 target 中存在该 key 且双方值都是字典 → 递归合并
+        if key in target and isinstance(target[key], dict) and isinstance(value, dict):
+            deep_update(target[key], value)
+        else:
+            # 直接覆盖或添加(包括非字典类型或 target 中不存在该 key 的情况)
+            target[key] = value
+    return target
+
+
+if __name__ == "__main__":
+    run_code = 0

+ 1 - 1
models_processing/model_tf/lstm.yaml

@@ -2,7 +2,7 @@ Model:
   add_train: true
   batch_size: 64
   dropout_rate: 0.2
-  epoch: 200
+  epoch: 500
   fusion: true
   hidden_size: 64
   his_points: 16

+ 2 - 1
models_processing/model_tf/tf_bp_pre.py

@@ -17,6 +17,7 @@ from copy import deepcopy
 model_lock = Lock()
 from itertools import chain
 from common.logs import Log
+from common.data_utils import deep_update
 from tf_bp import BPHandler
 logger = Log('tf_bp').logger
 np.random.seed(42)  # NumPy随机种子
@@ -35,7 +36,7 @@ def update_config():
     request_args = request.values.to_dict()
     # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
-    current_config.update(request_args)
+    current_config = deep_update(current_config, request_args)
 
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)

+ 2 - 1
models_processing/model_tf/tf_bp_train.py

@@ -17,6 +17,7 @@ from models_processing.model_tf.tf_bp import BPHandler
 from common.database_dml_koi import *
 import matplotlib.pyplot as plt
 from common.logs import Log
+from common.data_utils import deep_update
 logger = Log('tf_bp').logger
 np.random.seed(42)  # NumPy随机种子
 app = Flask('tf_bp_train——service')
@@ -34,7 +35,7 @@ def update_config():
     request_args = request.values.to_dict()
     # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
-    current_config.update(request_args)
+    current_config = deep_update(current_config, request_args)
 
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)

+ 2 - 1
models_processing/model_tf/tf_cnn_pre.py

@@ -17,6 +17,7 @@ from copy import deepcopy
 model_lock = Lock()
 from itertools import chain
 from common.logs import Log
+from common.data_utils import deep_update
 from tf_cnn import CNNHandler
 # logger = Log('tf_bp').logger()
 logger = Log('tf_cnn').logger
@@ -36,7 +37,7 @@ def update_config():
     request_args = request.values.to_dict()
     # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
-    current_config.update(request_args)
+    current_config = deep_update(current_config, request_args)
 
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)

+ 2 - 1
models_processing/model_tf/tf_cnn_train.py

@@ -16,6 +16,7 @@ from models_processing.model_tf.tf_cnn import CNNHandler
 from common.database_dml_koi import *
 import matplotlib.pyplot as plt
 from common.logs import Log
+from common.data_utils import deep_update
 # logger = logging.getLogger()
 logger = Log('tf_cnn').logger
 np.random.seed(42)  # NumPy随机种子
@@ -33,7 +34,7 @@ def update_config():
     request_args = request.values.to_dict()
     # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
-    current_config.update(request_args)
+    current_config = deep_update(current_config, request_args)
 
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)

+ 2 - 2
models_processing/model_tf/tf_lstm2_pre.py

@@ -18,6 +18,7 @@ model_lock = Lock()
 from itertools import chain
 from common.logs import Log
 from tf_lstm import TSHandler
+from common.data_utils import deep_update
 # logger = Log('tf_bp').logger()
 logger = Log('tf_ts2').logger
 np.random.seed(42)  # NumPy随机种子
@@ -37,8 +38,7 @@ def update_config():
     # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
     request_args['time_series'] = request_args.get('time_series', 2)
-    current_config.update(request_args)
-
+    current_config = deep_update(current_config, request_args)
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)
     g.dh = DataHandler(logger, current_config)  # 每个请求独立实例

+ 3 - 2
models_processing/model_tf/tf_lstm2_train.py

@@ -15,6 +15,8 @@ from copy import deepcopy
 from models_processing.model_tf.tf_lstm import TSHandler
 from common.database_dml_koi import *
 from common.logs import Log
+from common.data_utils import deep_update
+
 logger = Log('tf_ts2').logger
 np.random.seed(42)  # NumPy随机种子
 app = Flask('tf_lstm2_train——service')
@@ -32,8 +34,7 @@ def update_config():
     # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
     request_args['time_series'] = request_args.get('time_series', 2)
-    current_config.update(request_args)
-
+    current_config = deep_update(current_config, request_args)
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)
     g.dh = DataHandler(logger, current_config)  # 每个请求独立实例

+ 2 - 2
models_processing/model_tf/tf_lstm3_pre.py

@@ -17,6 +17,7 @@ from copy import deepcopy
 model_lock = Lock()
 from itertools import chain
 from common.logs import Log
+from common.data_utils import deep_update
 logger = Log('tf_ts3').logger
 np.random.seed(42)  # NumPy随机种子
 # tf.set_random_seed(42)  # TensorFlow随机种子
@@ -36,8 +37,7 @@ def update_config():
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
     request_args['time_series'] = request_args.get('time_series', 3)
     request_args['lstm_type'] = request_args.get('lstm_type', 1)
-    current_config.update(request_args)
-
+    current_config = deep_update(current_config, request_args)
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)
     g.dh = DataHandler(logger, current_config)  # 每个请求独立实例

+ 2 - 1
models_processing/model_tf/tf_lstm3_train.py

@@ -14,6 +14,7 @@ import time, yaml, threading
 from copy import deepcopy
 from common.database_dml_koi import *
 from common.logs import Log
+from common.data_utils import deep_update
 logger = Log('tf_ts3').logger
 np.random.seed(42)  # NumPy随机种子
 app = Flask('tf_lstm3_train——service')
@@ -32,7 +33,7 @@ def update_config():
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
     request_args['time_series'] = request_args.get('time_series', 3)
     request_args['lstm_type'] = request_args.get('lstm_type', 1)
-    current_config.update(request_args)
+    current_config = deep_update(current_config, request_args)
 
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)

+ 2 - 1
models_processing/model_tf/tf_lstm_pre.py

@@ -17,6 +17,7 @@ from copy import deepcopy
 model_lock = Lock()
 from itertools import chain
 from common.logs import Log
+from common.data_utils import deep_update
 from tf_lstm import TSHandler
 # logger = Log('tf_bp').logger()
 logger = Log('tf_ts').logger
@@ -37,7 +38,7 @@ def update_config():
     # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
     request_args['time_series'] = request_args.get('time_series', 1)
-    current_config.update(request_args)
+    current_config = deep_update(current_config, request_args)
 
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)

+ 2 - 1
models_processing/model_tf/tf_lstm_train.py

@@ -15,6 +15,7 @@ from copy import deepcopy
 from models_processing.model_tf.tf_lstm import TSHandler
 from common.database_dml_koi import *
 from common.logs import Log
+from common.data_utils import deep_update
 logger = Log('tf_ts').logger
 np.random.seed(42)  # NumPy随机种子
 app = Flask('tf_lstm_train——service')
@@ -32,7 +33,7 @@ def update_config():
     # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
     request_args['time_series'] = request_args.get('time_series', 1)
-    current_config.update(request_args)
+    current_config = deep_update(current_config, request_args)
 
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)

+ 2 - 2
models_processing/model_tf/tf_lstm_zone_pre.py

@@ -18,6 +18,7 @@ from copy import deepcopy
 model_lock = Lock()
 from itertools import chain
 from common.logs import Log
+from common.data_utils import deep_update
 # logger = Log('tf_bp').logger()
 logger = Log('tf_ts').logger
 np.random.seed(42)  # NumPy随机种子
@@ -38,8 +39,7 @@ def update_config():
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
     request_args['time_series'] = request_args.get('time_series', 1)
     request_args['zone'] = request_args['zone'].split(',')
-    current_config.update(request_args)
-
+    current_config = deep_update(current_config, request_args)
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)
     g.dh = CustomDataHandler(logger, current_config)  # 每个请求独立实例

+ 2 - 1
models_processing/model_tf/tf_lstm_zone_train.py

@@ -15,6 +15,7 @@ from copy import deepcopy
 from models_processing.model_tf.tf_lstm_zone import TSHandler
 from common.database_dml_koi import *
 from common.logs import Log
+from common.data_utils import deep_update
 logger = Log('tf_ts').logger
 np.random.seed(42)  # NumPy随机种子
 app = Flask('tf_lstm_zone_train——service')
@@ -33,7 +34,7 @@ def update_config():
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
     request_args['time_series'] = request_args.get('time_series', 1)
     request_args['zone'] = request_args['zone'].split(',')
-    current_config.update(request_args)
+    current_config = deep_update(current_config, request_args)
 
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)

+ 2 - 1
models_processing/model_tf/tf_test_pre.py

@@ -17,6 +17,7 @@ from copy import deepcopy
 model_lock = Lock()
 from itertools import chain
 from common.logs import Log
+from common.data_utils import deep_update
 from tf_test import TSHandler
 # logger = Log('tf_bp').logger()
 logger = Log('tf_test').logger
@@ -36,7 +37,7 @@ def update_config():
     request_args = request.values.to_dict()
     # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
-    current_config.update(request_args)
+    current_config = deep_update(current_config, request_args)
 
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)

+ 2 - 1
models_processing/model_tf/tf_test_train.py

@@ -15,6 +15,7 @@ from copy import deepcopy
 from models_processing.model_tf.tf_test import TSHandler
 from common.database_dml_koi import *
 from common.logs import Log
+from common.data_utils import deep_update
 logger = Log('tf_test').logger
 np.random.seed(42)  # NumPy随机种子
 app = Flask('tf_test_train——service')
@@ -31,7 +32,7 @@ def update_config():
     request_args = request.values.to_dict()
     # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
     request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
-    current_config.update(request_args)
+    current_config = deep_update(current_config, request_args)
 
     # 存储到请求上下文
     g.opt = argparse.Namespace(**current_config)