|
@@ -14,11 +14,10 @@ import time
|
|
from app.common.tf_fmi import FMIHandler
|
|
from app.common.tf_fmi import FMIHandler
|
|
from app.common.dbmg import MongoUtils
|
|
from app.common.dbmg import MongoUtils
|
|
from app.common.logs import logger
|
|
from app.common.logs import logger
|
|
|
|
+from copy import deepcopy
|
|
np.random.seed(42) # NumPy随机种子
|
|
np.random.seed(42) # NumPy随机种子
|
|
# tf.set_random_seed(42) # TensorFlow随机种子
|
|
# tf.set_random_seed(42) # TensorFlow随机种子
|
|
|
|
|
|
-dh = DataHandler(logger, params)
|
|
|
|
-fmi = FMIHandler(logger, params)
|
|
|
|
mgUtils = MongoUtils(logger)
|
|
mgUtils = MongoUtils(logger)
|
|
|
|
|
|
def model_training(train_data, input_file, cap):
|
|
def model_training(train_data, input_file, cap):
|
|
@@ -29,6 +28,10 @@ def model_training(train_data, input_file, cap):
|
|
farm_id = input_file.split('/')[-2]
|
|
farm_id = input_file.split('/')[-2]
|
|
output_file = input_file.replace('IN', 'OUT')
|
|
output_file = input_file.replace('IN', 'OUT')
|
|
status_file = 'STATUS.TXT'
|
|
status_file = 'STATUS.TXT'
|
|
|
|
+ # 创建线程独立的实例
|
|
|
|
+ local_params = deepcopy(params)
|
|
|
|
+ dh = DataHandler(logger, local_params)
|
|
|
|
+ fmi = FMIHandler(logger, local_params)
|
|
try:
|
|
try:
|
|
# ------------ 获取数据,预处理训练数据 ------------
|
|
# ------------ 获取数据,预处理训练数据 ------------
|
|
dh.opt.cap = cap
|
|
dh.opt.cap = cap
|
|
@@ -56,16 +59,16 @@ def model_training(train_data, input_file, cap):
|
|
# 更新算法状态:1. 启动成功
|
|
# 更新算法状态:1. 启动成功
|
|
write_number_to_file(os.path.join(output_file, status_file), 1, 1, 'rewrite')
|
|
write_number_to_file(os.path.join(output_file, status_file), 1, 1, 'rewrite')
|
|
# ------------ 组装模型数据 ------------
|
|
# ------------ 组装模型数据 ------------
|
|
- params['Model']['features'] = ','.join(dh.opt.features)
|
|
|
|
- params.update({
|
|
|
|
- 'params': json.dumps(params),
|
|
|
|
|
|
+ local_params['Model']['features'] = ','.join(dh.opt.features)
|
|
|
|
+ local_params.update({
|
|
|
|
+ 'params': json.dumps(local_params),
|
|
'descr': f'南网竞赛-{farm_id}',
|
|
'descr': f'南网竞赛-{farm_id}',
|
|
'gen_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
|
|
'gen_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
|
|
- 'model_table': params['model_table'] + farm_id,
|
|
|
|
- 'scaler_table': params['scaler_table'] + farm_id
|
|
|
|
|
|
+ 'model_table': local_params['model_table'] + farm_id,
|
|
|
|
+ 'scaler_table': local_params['scaler_table'] + farm_id
|
|
})
|
|
})
|
|
- mgUtils.insert_trained_model_into_mongo(ts_model, params)
|
|
|
|
- mgUtils.insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, params)
|
|
|
|
|
|
+ mgUtils.insert_trained_model_into_mongo(ts_model, local_params)
|
|
|
|
+ mgUtils.insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, local_params)
|
|
# 更新算法状态:正常结束
|
|
# 更新算法状态:正常结束
|
|
write_number_to_file(os.path.join(output_file, status_file), 2, 2)
|
|
write_number_to_file(os.path.join(output_file, status_file), 2, 2)
|
|
except Exception as e:
|
|
except Exception as e:
|