|
@@ -6,11 +6,12 @@
|
|
|
# @Company: shenyang JY
|
|
|
import json, copy
|
|
|
import numpy as np
|
|
|
-from flask import Flask, request, jsonify
|
|
|
+from flask import Flask, request, jsonify, g
|
|
|
import traceback, uuid
|
|
|
import logging, argparse
|
|
|
from data_processing.data_operation.data_handler import DataHandler
|
|
|
import time, yaml, threading
|
|
|
+from copy import deepcopy
|
|
|
from models_processing.model_tf.tf_test import TSHandler
|
|
|
from common.database_dml_koi import *
|
|
|
from common.logs import Log
|
|
@@ -18,32 +19,35 @@ logger = Log('tf_test').logger
|
|
|
np.random.seed(42) # NumPy随机种子
|
|
|
app = Flask('tf_test_train——service')
|
|
|
|
|
|
-with app.app_context():
|
|
|
- current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
- with open(os.path.join(current_dir, 'test.yaml'), 'r', encoding='utf-8') as f:
|
|
|
- args = yaml.safe_load(f)
|
|
|
-
|
|
|
- dh = DataHandler(logger, args)
|
|
|
- ts = TSHandler(logger, args)
|
|
|
+current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
+with open(os.path.join(current_dir, 'test.yaml'), 'r', encoding='utf-8') as f:
|
|
|
+ global_config = yaml.safe_load(f) # 只读的全局配置
|
|
|
|
|
|
@app.before_request
|
|
|
def update_config():
|
|
|
# ------------ 整理参数,整合请求参数 ------------
|
|
|
- args_dict = request.values.to_dict()
|
|
|
- args_dict['features'] = args_dict['features'].split(',')
|
|
|
- args.update(args_dict)
|
|
|
- opt = argparse.Namespace(**args)
|
|
|
- dh.opt = opt
|
|
|
- ts.opt = opt
|
|
|
- logger.info(args)
|
|
|
+ # 深拷贝全局配置 + 合并请求参数
|
|
|
+ current_config = deepcopy(global_config)
|
|
|
+ request_args = request.values.to_dict()
|
|
|
+ # features参数规则:1.有传入,解析,覆盖 2. 无传入,不覆盖,原始值
|
|
|
+ request_args['features'] = request_args['features'].split(',') if 'features' in request_args else current_config['features']
|
|
|
+ current_config.update(request_args)
|
|
|
+
|
|
|
+ # 存储到请求上下文
|
|
|
+ g.opt = argparse.Namespace(**current_config)
|
|
|
+ g.dh = DataHandler(logger, g.opt) # 每个请求独立实例
|
|
|
+ g.ts = TSHandler(logger, g.opt)
|
|
|
|
|
|
-@app.route('/nn_test_training', methods=['POST'])
|
|
|
+@app.route('/tf_test_training', methods=['POST'])
|
|
|
def model_training_test():
|
|
|
# 获取程序开始时间
|
|
|
start_time = time.time()
|
|
|
result = {}
|
|
|
success = 0
|
|
|
- print("Program starts execution!")
|
|
|
+ dh = g.dh
|
|
|
+ ts = g.ts
|
|
|
+ args = deepcopy(g.opt.__dict__)
|
|
|
+ logger.info("Program starts execution!")
|
|
|
try:
|
|
|
# ------------ 获取数据,预处理训练数据 ------------
|
|
|
train_data = get_data_from_mongo(args)
|