David 2 weeks ago
parent
commit
3571e6575a
1 changed files with 23 additions and 1 deletions
  1. 23 1
      models_processing/model_tf/tf_lstm.py

+ 23 - 1
models_processing/model_tf/tf_lstm.py

@@ -37,7 +37,7 @@ class TSHandler(object):
             self.logger.info("加载模型权重失败:{}".format(e.args))
 
     @staticmethod
-    def get_keras_model(opt, time_series=1, lstm_type=1):
+    def get_keras_model_20250514(opt, time_series=1, lstm_type=1):
         loss = region_loss(opt)
         l1_reg = regularizers.l1(opt.Model['lambda_value_1'])
         l2_reg = regularizers.l2(opt.Model['lambda_value_2'])
@@ -56,6 +56,28 @@ class TSHandler(object):
         model.compile(loss=loss, optimizer=adam)
         return model
 
+    @staticmethod
+    def get_keras_model(opt, time_series=1, lstm_type=1):
+        loss = region_loss(opt)
+        l1_reg = regularizers.l1(opt.Model['lambda_value_1'])
+        l2_reg = regularizers.l2(opt.Model['lambda_value_2'])
+        nwp_input = Input(shape=(opt.Model['time_step'] * time_series, opt.Model['input_size']), name='nwp')
+
+        con1 = Conv1D(filters=64, kernel_size=1, strides=1, padding='valid', activation='relu',
+                      kernel_regularizer=l2_reg)(nwp_input)
+        con1_p = MaxPooling1D(pool_size=1, strides=1, padding='valid', data_format='channels_last')(con1)
+        nwp_lstm = LSTM(units=opt.Model['hidden_size'], return_sequences=False, kernel_regularizer=l2_reg)(con1_p)
+        if lstm_type == 2:
+            output = Dense(opt.Model['time_step'], name='cdq_output')(nwp_lstm)
+        else:
+            output = Dense(opt.Model['time_step'] * time_series, name='cdq_output')(nwp_lstm)
+
+        model = Model(nwp_input, output)
+        adam = optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-7, amsgrad=True)
+        model.compile(loss=loss, optimizer=adam)
+
+        return model
+
     def train_init(self):
         try:
             # 进行加强训练,支持修模