tf_lstm.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. #!/usr/bin/env python
  2. # -*- coding:utf-8 -*-
  3. # @FileName :tf_lstm.py
  4. # @Time :2025/2/12 14:03
  5. # @Author :David
  6. # @Company: shenyang JY
  7. import os.path
  8. from keras.layers import Input, Dense, LSTM, concatenate, Conv1D, Conv2D, MaxPooling1D, BatchNormalization, Flatten, Dropout
  9. from keras.models import Model, load_model
  10. from keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard
  11. from keras import optimizers, regularizers
  12. import keras.backend as K
  13. from common.database_dml import *
  14. import numpy as np
  15. from sqlalchemy.ext.instrumentation import find_native_user_instrumentation_hook
  16. np.random.seed(42)
  17. from models_processing.losses.loss_cdq import SouthLoss, NorthEastLoss
  18. import tensorflow as tf
  19. tf.compat.v1.set_random_seed(1234)
  20. from threading import Lock
  21. model_lock = Lock()
  22. def rmse(y_true, y_pred):
  23. return K.sqrt(K.mean(K.square(y_pred - y_true)))
  24. var_dir = os.path.dirname(os.path.dirname(__file__))
  25. class TSHandler(object):
  26. model = None
  27. train = False
  28. def __init__(self, logger):
  29. self.logger = logger
  30. self.model = None
  31. def get_model(self, args):
  32. """
  33. 单例模式+线程锁,防止在异步加载时引发线程安全
  34. """
  35. try:
  36. with model_lock:
  37. self.model = get_h5_model_from_mongo(args, {'rmse': rmse})
  38. except Exception as e:
  39. self.logger.info("加载模型权重失败:{}".format(e.args))
  40. @staticmethod
  41. def get_keras_model(opt):
  42. # db_loss = NorthEastLoss(opt)
  43. south_loss = SouthLoss(opt)
  44. l1_reg = regularizers.l1(opt.Model['lambda_value_1'])
  45. l2_reg = regularizers.l2(opt.Model['lambda_value_2'])
  46. nwp_input = Input(shape=(opt.Model['time_step'], opt.Model['input_size_nwp']), name='nwp')
  47. con1 = Conv1D(filters=64, kernel_size=5, strides=1, padding='valid', activation='relu', kernel_regularizer=l2_reg)(nwp_input)
  48. nwp = MaxPooling1D(pool_size=5, strides=1, padding='valid', data_format='channels_last')(con1)
  49. nwp_lstm = LSTM(units=opt.Model['hidden_size'], return_sequences=False, kernel_regularizer=l2_reg)(nwp)
  50. output = Dense(opt.Model['output_size'], name='cdq_output')(nwp_lstm)
  51. model = Model(nwp_input, output)
  52. adam = optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-7, amsgrad=True)
  53. model.compile(loss=south_loss, optimizer=adam)
  54. return model
  55. def train_init(self, opt):
  56. try:
  57. if opt.Model['add_train']:
  58. # 进行加强训练,支持修模
  59. base_train_model = get_h5_model_from_mongo(vars(opt), {'rmse': rmse})
  60. base_train_model.summary()
  61. self.logger.info("已加载加强训练基础模型")
  62. else:
  63. base_train_model = self.get_keras_model(opt)
  64. return base_train_model
  65. except Exception as e:
  66. self.logger.info("加强训练加载模型权重失败:{}".format(e.args))
  67. def training(self, opt, train_and_valid_data):
  68. model = self.train_init(opt)
  69. model.summary()
  70. train_x, train_y, valid_x, valid_y = train_and_valid_data
  71. early_stop = EarlyStopping(monitor='val_loss', patience=opt.Model['patience'], mode='auto')
  72. history = model.fit(train_x, train_y, batch_size=opt.Model['batch_size'], epochs=opt.Model['epoch'], verbose=2, validation_data=(valid_x, valid_y), callbacks=[early_stop])
  73. loss = np.round(history.history['loss'], decimals=5)
  74. val_loss = np.round(history.history['val_loss'], decimals=5)
  75. self.logger.info("-----模型训练经过{}轮迭代-----".format(len(loss)))
  76. self.logger.info("训练集损失函数为:{}".format(loss))
  77. self.logger.info("验证集损失函数为:{}".format(val_loss))
  78. return model
  79. def predict(self, test_X, batch_size=1):
  80. result = TSHandler.model.predict(test_X, batch_size=batch_size)
  81. self.logger.info("执行预测方法")
  82. return result
if __name__ == "__main__":
    # Placeholder entry point: running this module directly only performs the
    # module-level setup (imports, random seeds, lock creation). No training or
    # prediction is started.
    run_code = 0