Browse Source

多线程更新04031518

David 2 tháng trước cách đây
mục cha
commit
a4482b71f2

+ 20 - 18
models_processing/model_tf/tf_bp_pre.py

@@ -13,6 +13,7 @@ from common.processing_data_common import missing_features, str_to_list
 from data_processing.data_operation.data_handler import DataHandler
 from threading import Lock
 import time, yaml
+from copy import deepcopy
 model_lock = Lock()
 from itertools import chain
 from common.logs import Log
@@ -22,34 +23,35 @@ np.random.seed(42)  # NumPy随机种子
 # tf.set_random_seed(42)  # TensorFlow随机种子
 app = Flask('tf_bp_pre——service')
 
-with app.app_context():
-    current_dir = os.path.dirname(os.path.abspath(__file__))
-    with open(os.path.join(current_dir, 'bp.yaml'), 'r', encoding='utf-8') as f:
-        args = yaml.safe_load(f)
-    dh = DataHandler(logger, args)
-    bp = BPHandler(logger, args)
-    global opt
+current_dir = os.path.dirname(os.path.abspath(__file__))
+with open(os.path.join(current_dir, 'bp.yaml'), 'r', encoding='utf-8') as f:
+    global_config = yaml.safe_load(f)  # 只读的全局配置
 
 @app.before_request
 def update_config():
     # ------------ 整理参数,整合请求参数 ------------
-    args_dict = request.values.to_dict()
-    if 'features' in args_dict:
-        args_dict['features'] = args_dict['features'].split(',')
-    args.update(args_dict)
-    opt = argparse.Namespace(**args)
-    dh.opt = opt
-    bp.opt = opt
-    g.opt = opt
-    logger.info(args)
+    # 深拷贝全局配置 + 合并请求参数
+    current_config = deepcopy(global_config)
+    request_args = request.values.to_dict()
+    # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
+    request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
+    current_config.update(request_args)
 
-@app.route('/nn_bp_predict', methods=['POST'])
+    # 存储到请求上下文
+    g.opt = argparse.Namespace(**current_config)
+    g.dh = DataHandler(logger, g.opt)  # 每个请求独立实例
+    g.bp = BPHandler(logger, g.opt)
+
+@app.route('/tf_bp_predict', methods=['POST'])
 def model_prediction_bp():
     # 获取程序开始时间
     start_time = time.time()
     result = {}
     success = 0
-    print("Program starts execution!")
+    dh = g.dh
+    bp = g.bp
+    args = g.opt.__dict__
+    logger.info("Program starts execution!")
     try:
         # ------------ 获取数据,预处理预测数据------------
         pre_data = get_data_from_mongo(args)

+ 23 - 17
models_processing/model_tf/tf_bp_train.py

@@ -5,13 +5,14 @@
 # @Author    :David
 # @Company: shenyang JY
 
-import json, copy
+import json
 import numpy as np
-from flask import Flask, request
+from flask import Flask, request, g
 import traceback
 import logging, argparse
 from data_processing.data_operation.data_handler import DataHandler
 import time, yaml
+from copy import deepcopy
 from models_processing.model_tf.tf_bp import BPHandler
 from common.database_dml_koi import *
 import matplotlib.pyplot as plt
@@ -20,31 +21,36 @@ logger = Log('tf_bp').logger
 np.random.seed(42)  # NumPy随机种子
 app = Flask('tf_bp_train——service')
 
-with app.app_context():
-    current_dir = os.path.dirname(os.path.abspath(__file__))
-    with open(os.path.join(current_dir, 'bp.yaml'), 'r', encoding='utf-8') as f:
-        args = yaml.safe_load(f)
-    dh = DataHandler(logger, args)
-    bp = BPHandler(logger, args)
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+with open(os.path.join(current_dir, 'bp.yaml'), 'r', encoding='utf-8') as f:
+    global_config = yaml.safe_load(f)  # 只读的全局配置
 
 @app.before_request
 def update_config():
     # ------------ 整理参数,整合请求参数 ------------
-    args_dict = request.values.to_dict()
-    args_dict['features'] = args_dict['features'].split(',')
-    args.update(args_dict)
-    opt = argparse.Namespace(**args)
-    dh.opt = opt
-    bp.opt = opt
-    logger.info(args)
+    # 深拷贝全局配置 + 合并请求参数
+    current_config = deepcopy(global_config)
+    request_args = request.values.to_dict()
+    # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
+    request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
+    current_config.update(request_args)
+
+    # 存储到请求上下文
+    g.opt = argparse.Namespace(**current_config)
+    g.dh = DataHandler(logger, g.opt)  # 每个请求独立实例
+    g.bp = BPHandler(logger, g.opt)
 
-@app.route('/nn_bp_training', methods=['POST'])
+@app.route('/tf_bp_training', methods=['POST'])
 def model_training_bp():
     # 获取程序开始时间
     start_time = time.time()
     result = {}
     success = 0
-    print("Program starts execution!")
+    dh = g.dh
+    bp = g.bp
+    args = deepcopy(g.opt.__dict__)
+    logger.info("Program starts execution!")
     try:
         # ------------ 获取数据,预处理训练数据 ------------
         train_data = get_data_from_mongo(args)