tf_lstm.py

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName :tf_lstm.py
# @Time :2025/2/12 14:03
# @Author :David
# @Company: shenyang JY
from tensorflow.keras.layers import Input, Dense, LSTM, Conv1D, MaxPooling1D
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras import optimizers, regularizers
from app.model.losses import region_loss
import numpy as np
from app.common.dbmg import MongoUtils
# from app.model.losses import rmse
from threading import Lock
import argparse

model_lock = Lock()


class TSHandler(object):
    def __init__(self, logger, args):
        self.logger = logger
        self.opt = args.parse_args_and_yaml()
        self.model = None
        self.model_params = None
        self.mongoUtils = MongoUtils(logger)

    def get_model(self, args):
        """
        Singleton-style loading guarded by a thread lock, so that
        asynchronous loads do not run into thread-safety issues.
        """
        try:
            with model_lock:
                loss = region_loss(self.opt)
                self.model, self.model_params = self.mongoUtils.get_keras_model_from_mongo(
                    args, {type(loss).__name__: loss})
        except Exception as e:
            self.logger.info("Failed to load model weights: {}".format(e.args))

    @staticmethod
    def get_keras_model(opt):
        loss = region_loss(opt)
        l1_reg = regularizers.l1(opt.Model['lambda_value_1'])  # defined but not applied below
        l2_reg = regularizers.l2(opt.Model['lambda_value_2'])
        nwp_input = Input(shape=(opt.Model['time_step'], opt.Model['input_size']), name='nwp')

        # Conv1D front end extracts local temporal features from the NWP sequence;
        # the LSTM then summarizes the sequence and a Dense head produces the output.
        con1 = Conv1D(filters=64, kernel_size=5, strides=1, padding='valid', activation='relu',
                      kernel_regularizer=l2_reg)(nwp_input)
        con1_p = MaxPooling1D(pool_size=5, strides=1, padding='valid', data_format='channels_last')(con1)
        nwp_lstm = LSTM(units=opt.Model['hidden_size'], return_sequences=False, kernel_regularizer=l2_reg)(con1_p)
        output = Dense(opt.Model['output_size'], name='cdq_output')(nwp_lstm)

        model = Model(nwp_input, output)
        adam = optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-7, amsgrad=True)
        model.compile(loss=loss, optimizer=adam)
        return model

    def train_init(self):
        try:
            # Reinforced (incremental) training: start from the base model persisted in MongoDB.
            loss = region_loss(self.opt)
            base_train_model, self.model_params = self.mongoUtils.get_keras_model_from_mongo(
                vars(self.opt), {type(loss).__name__: loss})
            base_train_model.summary()
            self.logger.info("Loaded the base model for reinforced training")
            return base_train_model
        except Exception as e:
            self.logger.info("Reinforced training: failed to load model weights: {}".format(e.args))
            return False

    def training(self, model, train_and_valid_data):
        model.summary()
        train_x, train_y, valid_x, valid_y = train_and_valid_data
        early_stop = EarlyStopping(monitor='val_loss', patience=self.opt.Model['patience'], mode='auto')
        history = model.fit(train_x, train_y, batch_size=self.opt.Model['batch_size'], epochs=self.opt.Model['epoch'],
                            verbose=2, validation_data=(valid_x, valid_y), callbacks=[early_stop], shuffle=False)
        loss = np.round(history.history['loss'], decimals=5)
        val_loss = np.round(history.history['val_loss'], decimals=5)
        self.logger.info("----- Model training ran for {} epochs -----".format(len(loss)))
        self.logger.info("Training loss per epoch: {}".format(loss))
        self.logger.info("Validation loss per epoch: {}".format(val_loss))
        return model

    def predict(self, test_x, batch_size=1):
        result = self.model.predict(test_x, batch_size=batch_size)
        self.logger.info("Running the predict method")
        return result
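
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed on import). It assumes an
# `args` object exposing parse_args_and_yaml(), as required by __init__, and a
# YAML config providing the Model keys referenced above.
#
#   import logging
#   logger = logging.getLogger("tf_lstm")
#   ts = TSHandler(logger, args)                        # args comes from the service layer
#   model = TSHandler.get_keras_model(ts.opt)           # fresh CNN+LSTM model from config
#   ts.model = ts.training(model, (train_x, train_y, valid_x, valid_y))
#   result = ts.predict(test_x, batch_size=1)
# ---------------------------------------------------------------------------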

if __name__ == "__main__":
    run_code = 0
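

# ---------------------------------------------------------------------------
# Shape smoke test (an illustrative sketch, not part of the original pipeline).
# It rebuilds the Conv1D -> MaxPooling1D -> LSTM -> Dense topology from
# get_keras_model with hard-coded sizes and plain 'mse' in place of the custom
# region_loss, so the tensor shapes can be checked without a MongoDB
# connection or a YAML config.
# ---------------------------------------------------------------------------
def _demo_shapes(time_step=16, input_size=24, hidden_size=64, output_size=16):
    demo_in = Input(shape=(time_step, input_size), name='nwp')
    x = Conv1D(filters=64, kernel_size=5, padding='valid', activation='relu')(demo_in)  # -> (None, 12, 64)
    x = MaxPooling1D(pool_size=5, strides=1, padding='valid')(x)                        # -> (None, 8, 64)
    x = LSTM(units=hidden_size, return_sequences=False)(x)                              # -> (None, 64)
    demo_out = Dense(output_size, name='cdq_output')(x)                                 # -> (None, 16)
    demo_model = Model(demo_in, demo_out)
    demo_model.compile(loss='mse', optimizer=optimizers.Adam(learning_rate=0.001))
    dummy_x = np.random.rand(4, time_step, input_size).astype('float32')
    print(demo_model.predict(dummy_x, verbose=0).shape)  # expected: (4, output_size)

# _demo_shapes()  # uncomment to run a quick shape check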