|
@@ -13,6 +13,7 @@ from common.processing_data_common import missing_features, str_to_list
|
|
from data_processing.data_operation.data_handler import DataHandler
|
|
from data_processing.data_operation.data_handler import DataHandler
|
|
from threading import Lock
|
|
from threading import Lock
|
|
import time, yaml
|
|
import time, yaml
|
|
|
|
+from copy import deepcopy
|
|
model_lock = Lock()
|
|
model_lock = Lock()
|
|
from itertools import chain
|
|
from itertools import chain
|
|
from common.logs import Log
|
|
from common.logs import Log
|
|
@@ -22,34 +23,35 @@ np.random.seed(42) # NumPy随机种子
|
|
# tf.set_random_seed(42) # TensorFlow随机种子
|
|
# tf.set_random_seed(42) # TensorFlow随机种子
|
|
app = Flask('tf_bp_pre——service')
|
|
app = Flask('tf_bp_pre——service')
|
|
|
|
|
|
-with app.app_context():
|
|
|
|
- current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
|
- with open(os.path.join(current_dir, 'bp.yaml'), 'r', encoding='utf-8') as f:
|
|
|
|
- args = yaml.safe_load(f)
|
|
|
|
- dh = DataHandler(logger, args)
|
|
|
|
- bp = BPHandler(logger, args)
|
|
|
|
- global opt
|
|
|
|
|
|
+current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
|
+with open(os.path.join(current_dir, 'bp.yaml'), 'r', encoding='utf-8') as f:
|
|
|
|
+ global_config = yaml.safe_load(f) # 只读的全局配置
|
|
|
|
|
|
@app.before_request
|
|
@app.before_request
|
|
def update_config():
|
|
def update_config():
|
|
# ------------ 整理参数,整合请求参数 ------------
|
|
# ------------ 整理参数,整合请求参数 ------------
|
|
- args_dict = request.values.to_dict()
|
|
|
|
- if 'features' in args_dict:
|
|
|
|
- args_dict['features'] = args_dict['features'].split(',')
|
|
|
|
- args.update(args_dict)
|
|
|
|
- opt = argparse.Namespace(**args)
|
|
|
|
- dh.opt = opt
|
|
|
|
- bp.opt = opt
|
|
|
|
- g.opt = opt
|
|
|
|
- logger.info(args)
|
|
|
|
|
|
+ # 深拷贝全局配置 + 合并请求参数
|
|
|
|
+ current_config = deepcopy(global_config)
|
|
|
|
+ request_args = request.values.to_dict()
|
|
|
|
+ # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
|
|
|
|
+ request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
|
|
|
|
+ current_config.update(request_args)
|
|
|
|
|
|
-@app.route('/nn_bp_predict', methods=['POST'])
|
|
|
|
|
|
+ # 存储到请求上下文
|
|
|
|
+ g.opt = argparse.Namespace(**current_config)
|
|
|
|
+ g.dh = DataHandler(logger, g.opt) # 每个请求独立实例
|
|
|
|
+ g.bp = BPHandler(logger, g.opt)
|
|
|
|
+
|
|
|
|
+@app.route('/tf_bp_predict', methods=['POST'])
|
|
def model_prediction_bp():
|
|
def model_prediction_bp():
|
|
# 获取程序开始时间
|
|
# 获取程序开始时间
|
|
start_time = time.time()
|
|
start_time = time.time()
|
|
result = {}
|
|
result = {}
|
|
success = 0
|
|
success = 0
|
|
- print("Program starts execution!")
|
|
|
|
|
|
+ dh = g.dh
|
|
|
|
+ bp = g.bp
|
|
|
|
+ args = g.opt.__dict__
|
|
|
|
+ logger.info("Program starts execution!")
|
|
try:
|
|
try:
|
|
# ------------ 获取数据,预处理预测数据------------
|
|
# ------------ 获取数据,预处理预测数据------------
|
|
pre_data = get_data_from_mongo(args)
|
|
pre_data = get_data_from_mongo(args)
|