|
@@ -8,11 +8,12 @@ import json, copy
|
|
|
import numpy as np
|
|
|
from flask import Flask, request, g
|
|
|
import logging, argparse, traceback
|
|
|
-from common.database_dml import *
|
|
|
+from common.database_dml_koi import *
|
|
|
from common.processing_data_common import missing_features, str_to_list
|
|
|
from data_processing.data_operation.data_handler import DataHandler
|
|
|
from threading import Lock
|
|
|
import time, yaml, os
|
|
|
+from copy import deepcopy
|
|
|
model_lock = Lock()
|
|
|
from itertools import chain
|
|
|
from common.logs import Log
|
|
@@ -23,34 +24,35 @@ np.random.seed(42) # NumPy随机种子
|
|
|
# tf.set_random_seed(42) # TensorFlow随机种子
|
|
|
app = Flask('tf_cnn_pre——service')
|
|
|
|
|
|
-with app.app_context():
|
|
|
- current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
- with open(os.path.join(current_dir, 'cnn.yaml'), 'r', encoding='utf-8') as f:
|
|
|
- args = yaml.safe_load(f)
|
|
|
-
|
|
|
- dh = DataHandler(logger, args)
|
|
|
- cnn = CNNHandler(logger, args)
|
|
|
+current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
+with open(os.path.join(current_dir, 'cnn.yaml'), 'r', encoding='utf-8') as f:
|
|
|
+ global_config = yaml.safe_load(f) # 只读的全局配置
|
|
|
|
|
|
@app.before_request
|
|
|
def update_config():
|
|
|
# ------------ 整理参数,整合请求参数 ------------
|
|
|
- args_dict = request.values.to_dict()
|
|
|
- if 'features' in args_dict:
|
|
|
- args_dict['features'] = args_dict['features'].split(',')
|
|
|
- args.update(args_dict)
|
|
|
- opt = argparse.Namespace(**args)
|
|
|
- dh.opt = opt
|
|
|
- cnn.opt = opt
|
|
|
- g.opt = opt
|
|
|
- logger.info(args)
|
|
|
+ # 深拷贝全局配置 + 合并请求参数
|
|
|
+ current_config = deepcopy(global_config)
|
|
|
+ request_args = request.values.to_dict()
|
|
|
+ # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
|
|
|
+ request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
|
|
|
+ current_config.update(request_args)
|
|
|
+
|
|
|
+ # 存储到请求上下文
|
|
|
+ g.opt = argparse.Namespace(**current_config)
|
|
|
+ g.dh = DataHandler(logger, g.opt) # 每个请求独立实例
|
|
|
+ g.cnn = CNNHandler(logger, g.opt)
|
|
|
|
|
|
-@app.route('/nn_cnn_predict', methods=['POST'])
|
|
|
+@app.route('/tf_cnn_predict', methods=['POST'])
|
|
|
def model_prediction_bp():
|
|
|
# 获取程序开始时间
|
|
|
start_time = time.time()
|
|
|
result = {}
|
|
|
success = 0
|
|
|
- print("Program starts execution!")
|
|
|
+ dh = g.dh
|
|
|
+ cnn = g.cnn
|
|
|
+ args = g.opt.__dict__
|
|
|
+ logger.info("Program starts execution!")
|
|
|
try:
|
|
|
pre_data = get_data_from_mongo(args)
|
|
|
if args.get('algorithm_test', 0):
|