|
@@ -13,11 +13,13 @@ from models_processing.losses.loss_cdq import rmse
|
|
import numpy as np
|
|
import numpy as np
|
|
from common.database_dml import *
|
|
from common.database_dml import *
|
|
from threading import Lock
|
|
from threading import Lock
|
|
|
|
+import argparse
|
|
model_lock = Lock()
|
|
model_lock = Lock()
|
|
|
|
|
|
class BPHandler(object):
|
|
class BPHandler(object):
|
|
- def __init__(self, logger):
|
|
|
|
|
|
+ def __init__(self, logger, args):
|
|
self.logger = logger
|
|
self.logger = logger
|
|
|
|
+ self.opt = argparse.Namespace(**args)
|
|
self.model = None
|
|
self.model = None
|
|
|
|
|
|
def get_model(self, args):
|
|
def get_model(self, args):
|
|
@@ -42,28 +44,28 @@ class BPHandler(object):
|
|
model.compile(loss=rmse, optimizer=adam)
|
|
model.compile(loss=rmse, optimizer=adam)
|
|
return model
|
|
return model
|
|
|
|
|
|
- def train_init(self, opt):
|
|
|
|
|
|
+ def train_init(self):
|
|
try:
|
|
try:
|
|
- if opt.Model['add_train']:
|
|
|
|
|
|
+ if self.opt.Model['add_train']:
|
|
# 进行加强训练,支持修模
|
|
# 进行加强训练,支持修模
|
|
- base_train_model = get_h5_model_from_mongo(vars(opt), {'rmse': rmse})
|
|
|
|
|
|
+ base_train_model = get_h5_model_from_mongo(vars(self.opt), {'rmse': rmse})
|
|
base_train_model.summary()
|
|
base_train_model.summary()
|
|
self.logger.info("已加载加强训练基础模型")
|
|
self.logger.info("已加载加强训练基础模型")
|
|
else:
|
|
else:
|
|
- base_train_model = self.get_keras_model(opt)
|
|
|
|
|
|
+ base_train_model = self.get_keras_model(self.opt)
|
|
return base_train_model
|
|
return base_train_model
|
|
except Exception as e:
|
|
except Exception as e:
|
|
self.logger.info("加强训练加载模型权重失败:{}".format(e.args))
|
|
self.logger.info("加强训练加载模型权重失败:{}".format(e.args))
|
|
|
|
|
|
- def training(self, opt, train_and_valid_data):
|
|
|
|
- model = self.train_init(opt)
|
|
|
|
|
|
+ def training(self, train_and_valid_data):
|
|
|
|
+ model = self.train_init()
|
|
# tf.reset_default_graph() # 清除默认图
|
|
# tf.reset_default_graph() # 清除默认图
|
|
train_x, train_y, valid_x, valid_y = train_and_valid_data
|
|
train_x, train_y, valid_x, valid_y = train_and_valid_data
|
|
print("----------", np.array(train_x[0]).shape)
|
|
print("----------", np.array(train_x[0]).shape)
|
|
print("++++++++++", np.array(train_x[1]).shape)
|
|
print("++++++++++", np.array(train_x[1]).shape)
|
|
model.summary()
|
|
model.summary()
|
|
- early_stop = EarlyStopping(monitor='val_loss', patience=opt.Model['patience'], mode='auto')
|
|
|
|
- history = model.fit(train_x, train_y, batch_size=opt.Model['batch_size'], epochs=opt.Model['epoch'], verbose=2, validation_data=(valid_x, valid_y), callbacks=[early_stop], shuffle=False)
|
|
|
|
|
|
+ early_stop = EarlyStopping(monitor='val_loss', patience=self.opt.Model['patience'], mode='auto')
|
|
|
|
+ history = model.fit(train_x, train_y, batch_size=self.opt.Model['batch_size'], epochs=self.opt.Model['epoch'], verbose=2, validation_data=(valid_x, valid_y), callbacks=[early_stop], shuffle=False)
|
|
loss = np.round(history.history['loss'], decimals=5)
|
|
loss = np.round(history.history['loss'], decimals=5)
|
|
val_loss = np.round(history.history['val_loss'], decimals=5)
|
|
val_loss = np.round(history.history['val_loss'], decimals=5)
|
|
self.logger.info("-----模型训练经过{}轮迭代-----".format(len(loss)))
|
|
self.logger.info("-----模型训练经过{}轮迭代-----".format(len(loss)))
|