# tf_lstm.py
# NOTE(review): this file was recovered from a scraped listing; the original
# caption ("tf_lstm.py 3.9 KB") and a run of concatenated gutter line numbers
# were removed, and the per-line "N. " prefixes / lost indentation repaired below.
  1. #!/usr/bin/env python
  2. # -*- coding:utf-8 -*-
  3. # @FileName :tf_lstm.py
  4. # @Time :2025/2/12 14:03
  5. # @Author :David
  6. # @Company: shenyang JY
  7. from tensorflow.keras.layers import Input, Dense, LSTM, concatenate, Conv1D, Conv2D, MaxPooling1D, Reshape, Flatten
  8. from tensorflow.keras.models import Model, load_model
  9. from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, ReduceLROnPlateau
  10. from tensorflow.keras import optimizers, regularizers
  11. from models_processing.losses.loss_cdq import rmse
  12. import numpy as np
  13. from common.database_dml_koi import *
  14. from models_processing.model_koi.settings import set_deterministic
  15. from threading import Lock
  16. import argparse
  17. model_lock = Lock()
  18. set_deterministic(42)
  19. class TSHandler(object):
  20. def __init__(self, logger, args):
  21. self.logger = logger
  22. self.opt = argparse.Namespace(**args)
  23. self.model = None
  24. def get_model(self, args):
  25. """
  26. 单例模式+线程锁,防止在异步加载时引发线程安全
  27. """
  28. try:
  29. with model_lock:
  30. self.model = get_h5_model_from_mongo(args, {'rmse': rmse})
  31. except Exception as e:
  32. self.logger.info("加载模型权重失败:{}".format(e.args))
  33. @staticmethod
  34. def get_keras_model(opt):
  35. # db_loss = NorthEastLoss(opt)
  36. # south_loss = SouthLoss(opt)
  37. l1_reg = regularizers.l1(opt.Model['lambda_value_1'])
  38. l2_reg = regularizers.l2(opt.Model['lambda_value_2'])
  39. nwp_input = Input(shape=(opt.Model['time_step'], opt.Model['input_size']), name='nwp')
  40. con1 = Conv1D(filters=64, kernel_size=5, strides=1, padding='valid', activation='relu', kernel_regularizer=l2_reg)(nwp_input)
  41. con1_p = MaxPooling1D(pool_size=5, strides=1, padding='valid', data_format='channels_last')(con1)
  42. nwp_lstm = LSTM(units=opt.Model['hidden_size'], return_sequences=False, kernel_regularizer=l2_reg)(con1_p)
  43. output = Dense(opt.Model['output_size'], name='cdq_output')(nwp_lstm)
  44. model = Model(nwp_input, output)
  45. adam = optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-7, amsgrad=True)
  46. model.compile(loss=rmse, optimizer=adam)
  47. return model
  48. def train_init(self):
  49. try:
  50. if self.opt.Model['add_train']:
  51. # 进行加强训练,支持修模
  52. base_train_model = get_h5_model_from_mongo(vars(self.opt), {'rmse': rmse})
  53. base_train_model.summary()
  54. self.logger.info("已加载加强训练基础模型")
  55. else:
  56. base_train_model = self.get_keras_model(self.opt)
  57. return base_train_model
  58. except Exception as e:
  59. self.logger.info("加强训练加载模型权重失败:{}".format(e.args))
  60. def training(self, train_and_valid_data):
  61. model = self.train_init()
  62. model.summary()
  63. train_x, train_y, valid_x, valid_y = train_and_valid_data
  64. early_stop = EarlyStopping(monitor='val_loss', patience=self.opt.Model['patience'], mode='auto')
  65. history = model.fit(train_x, train_y, batch_size=self.opt.Model['batch_size'], epochs=self.opt.Model['epoch'],
  66. verbose=2, validation_data=(valid_x, valid_y), callbacks=[early_stop], shuffle=False)
  67. loss = np.round(history.history['loss'], decimals=5)
  68. val_loss = np.round(history.history['val_loss'], decimals=5)
  69. self.logger.info("-----模型训练经过{}轮迭代-----".format(len(loss)))
  70. self.logger.info("训练集损失函数为:{}".format(loss))
  71. self.logger.info("验证集损失函数为:{}".format(val_loss))
  72. return model
  73. def predict(self, test_x, batch_size=1):
  74. result = self.model.predict(test_x, batch_size=batch_size)
  75. self.logger.info("执行预测方法")
  76. return result
  77. if __name__ == "__main__":
  78. run_code = 0