|
@@ -23,6 +23,7 @@ class TSHandler(object):
|
|
|
self.logger = logger
|
|
|
self.opt = argparse.Namespace(**args)
|
|
|
self.model = None
|
|
|
+ self.model_params = None
|
|
|
|
|
|
def get_model(self, args):
|
|
|
"""
|
|
@@ -31,7 +32,7 @@ class TSHandler(object):
|
|
|
try:
|
|
|
with model_lock:
|
|
|
loss = region_loss(self.opt)
|
|
|
- self.model = get_keras_model_from_mongo(args, {type(loss).__name__: loss})
|
|
|
+ self.model, self.model_params = get_keras_model_from_mongo(args, {type(loss).__name__: loss})
|
|
|
except Exception as e:
|
|
|
self.logger.info("加载模型权重失败:{}".format(e.args))
|
|
|
|
|
@@ -55,20 +56,17 @@ class TSHandler(object):
|
|
|
|
|
|
def train_init(self):
|
|
|
try:
|
|
|
- if self.opt.Model['add_train']:
|
|
|
- # 进行加强训练,支持修模
|
|
|
- loss = region_loss(self.opt)
|
|
|
- base_train_model = get_keras_model_from_mongo(vars(self.opt), {type(loss).__name__: loss})
|
|
|
- base_train_model.summary()
|
|
|
- self.logger.info("已加载加强训练基础模型")
|
|
|
- else:
|
|
|
- base_train_model = self.get_keras_model(self.opt)
|
|
|
+ # 进行加强训练,支持修模
|
|
|
+ loss = region_loss(self.opt)
|
|
|
+ base_train_model, self.model_params = get_keras_model_from_mongo(vars(self.opt), {type(loss).__name__: loss})
|
|
|
+ base_train_model.summary()
|
|
|
+ self.logger.info("已加载加强训练基础模型")
|
|
|
return base_train_model
|
|
|
except Exception as e:
|
|
|
- self.logger.info("加载模型权重失败:{}".format(e.args))
|
|
|
+ self.logger.info("加载加强训练模型权重失败:{}".format(e.args))
|
|
|
+ return False
|
|
|
|
|
|
- def training(self, train_and_valid_data):
|
|
|
- model = self.train_init()
|
|
|
+ def training(self, model, train_and_valid_data):
|
|
|
model.summary()
|
|
|
train_x, train_y, valid_x, valid_y = train_and_valid_data
|
|
|
early_stop = EarlyStopping(monitor='val_loss', patience=self.opt.Model['patience'], mode='auto')
|