|
@@ -10,7 +10,7 @@ from tensorflow.keras.layers import Input, Dense, LSTM, concatenate, Conv1D, Con
|
|
from tensorflow.keras.models import Model, load_model
|
|
from tensorflow.keras.models import Model, load_model
|
|
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, ReduceLROnPlateau
|
|
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, ReduceLROnPlateau
|
|
from tensorflow.keras import optimizers, regularizers
|
|
from tensorflow.keras import optimizers, regularizers
|
|
-from models_processing.losses.loss_cdq import rmse
|
|
|
|
|
|
+from models_processing.model_koi.losses import region_loss
|
|
from models_processing.model_koi.settings import set_deterministic
|
|
from models_processing.model_koi.settings import set_deterministic
|
|
import numpy as np
|
|
import numpy as np
|
|
from common.database_dml_koi import *
|
|
from common.database_dml_koi import *
|
|
@@ -31,15 +31,14 @@ class CNNHandler(object):
|
|
"""
|
|
"""
|
|
try:
|
|
try:
|
|
with model_lock:
|
|
with model_lock:
|
|
- # NPHandler.model = NPHandler.get_keras_model(opt)
|
|
|
|
- self.model = get_h5_model_from_mongo(args, {'rmse': rmse})
|
|
|
|
|
|
+ loss = region_loss(self.opt)
|
|
|
|
+ self.model = get_h5_model_from_mongo(args, {type(loss).__name__: loss})
|
|
except Exception as e:
|
|
except Exception as e:
|
|
self.logger.info("加载模型权重失败:{}".format(e.args))
|
|
self.logger.info("加载模型权重失败:{}".format(e.args))
|
|
|
|
|
|
@staticmethod
|
|
@staticmethod
|
|
def get_keras_model(opt):
|
|
def get_keras_model(opt):
|
|
- # db_loss = NorthEastLoss(opt)
|
|
|
|
- # south_loss = SouthLoss(opt)
|
|
|
|
|
|
+ loss = region_loss(opt)
|
|
l1_reg = regularizers.l1(opt.Model['lambda_value_1'])
|
|
l1_reg = regularizers.l1(opt.Model['lambda_value_1'])
|
|
l2_reg = regularizers.l2(opt.Model['lambda_value_2'])
|
|
l2_reg = regularizers.l2(opt.Model['lambda_value_2'])
|
|
nwp_input = Input(shape=(opt.Model['time_step'], opt.Model['input_size']), name='nwp')
|
|
nwp_input = Input(shape=(opt.Model['time_step'], opt.Model['input_size']), name='nwp')
|
|
@@ -53,21 +52,23 @@ class CNNHandler(object):
|
|
model = Model(inputs=nwp_input, outputs=output_f)
|
|
model = Model(inputs=nwp_input, outputs=output_f)
|
|
adam = optimizers.Adam(learning_rate=opt.Model['learning_rate'], beta_1=0.9, beta_2=0.999, epsilon=1e-7, amsgrad=True)
|
|
adam = optimizers.Adam(learning_rate=opt.Model['learning_rate'], beta_1=0.9, beta_2=0.999, epsilon=1e-7, amsgrad=True)
|
|
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.01, patience=5, verbose=1)
|
|
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.01, patience=5, verbose=1)
|
|
- model.compile(loss=rmse, optimizer=adam)
|
|
|
|
|
|
+
|
|
|
|
+ model.compile(loss=loss, optimizer=adam)
|
|
return model
|
|
return model
|
|
|
|
|
|
def train_init(self):
|
|
def train_init(self):
|
|
try:
|
|
try:
|
|
if self.opt.Model['add_train']:
|
|
if self.opt.Model['add_train']:
|
|
# 进行加强训练,支持修模
|
|
# 进行加强训练,支持修模
|
|
- base_train_model = get_h5_model_from_mongo(vars(self.opt), {'rmse': rmse})
|
|
|
|
|
|
+ loss = region_loss(self.opt)
|
|
|
|
+ base_train_model = get_h5_model_from_mongo(vars(self.opt), {type(loss).__name__: loss})
|
|
base_train_model.summary()
|
|
base_train_model.summary()
|
|
self.logger.info("已加载加强训练基础模型")
|
|
self.logger.info("已加载加强训练基础模型")
|
|
else:
|
|
else:
|
|
base_train_model = self.get_keras_model(self.opt)
|
|
base_train_model = self.get_keras_model(self.opt)
|
|
return base_train_model
|
|
return base_train_model
|
|
except Exception as e:
|
|
except Exception as e:
|
|
- self.logger.info("加强训练加载模型权重失败:{}".format(e.args))
|
|
|
|
|
|
+ self.logger.info("加载模型权重失败:{}".format(e.args))
|
|
|
|
|
|
def training(self, train_and_valid_data):
|
|
def training(self, train_and_valid_data):
|
|
model = self.train_init()
|
|
model = self.train_init()
|