|
@@ -15,6 +15,8 @@ from copy import deepcopy
|
|
|
from models_processing.model_tf.tf_lstm import TSHandler
|
|
|
from common.database_dml_koi import *
|
|
|
from common.logs import Log
|
|
|
+from common.data_utils import deep_update
|
|
|
+
|
|
|
logger = Log('tf_ts2').logger
|
|
|
np.random.seed(42) # NumPy随机种子
|
|
|
app = Flask('tf_lstm2_train——service')
|
|
@@ -32,8 +34,7 @@ def update_config():
|
|
|
# features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
|
|
|
request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
|
|
|
request_args['time_series'] = request_args.get('time_series', 2)
|
|
|
- current_config.update(request_args)
|
|
|
-
|
|
|
+ current_config = deep_update(current_config, request_args)
|
|
|
# 存储到请求上下文
|
|
|
g.opt = argparse.Namespace(**current_config)
|
|
|
g.dh = DataHandler(logger, current_config) # 每个请求独立实例
|