tf_bp.py 4.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. #!/usr/bin/env python
  2. # -*- coding:utf-8 -*-
# @FileName :tf_bp.py
  4. # @Time :2025/2/12 10:41
  5. # @Author :David
  6. # @Company: shenyang JY
  7. from tensorflow.keras.layers import Input, Dense, LSTM, concatenate, Conv1D, Conv2D, MaxPooling1D, Reshape, Flatten
  8. from tensorflow.keras.models import Model, load_model
  9. from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, ReduceLROnPlateau
  10. from tensorflow.keras import optimizers, regularizers
  11. import numpy as np
  12. from common.database_dml import *
  13. from threading import Lock
# Module-level lock guarding model deserialization: get_model() may be called
# from concurrent/async contexts, and Keras model loading is not thread-safe.
model_lock = Lock()
  15. class BPHandler(object):
  16. def __init__(self, logger):
  17. self.logger = logger
  18. self.model = None
  19. def get_model(self, args):
  20. """
  21. 单例模式+线程锁,防止在异步加载时引发线程安全
  22. """
  23. try:
  24. with model_lock:
  25. # NPHandler.model = NPHandler.get_keras_model(opt)
  26. self.model = get_h5_model_from_mongo(args)
  27. except Exception as e:
  28. self.logger.info("加载模型权重失败:{}".format(e.args))
  29. @staticmethod
  30. def get_keras_model(opt):
  31. # db_loss = NorthEastLoss(opt)
  32. # south_loss = SouthLoss(opt)
  33. from models_processing.losses.loss_cdq import rmse
  34. l1_reg = regularizers.l1(opt.Model['lambda_value_1'])
  35. l2_reg = regularizers.l2(opt.Model['lambda_value_2'])
  36. nwp_input = Input(shape=(opt.Model['time_step'], opt.Model['input_size']), name='nwp')
  37. con1 = Conv1D(filters=64, kernel_size=1, strides=1, padding='valid', activation='relu', kernel_regularizer=l2_reg)(nwp_input)
  38. d1 = Dense(32, activation='relu', name='d1', kernel_regularizer=l1_reg)(con1)
  39. nwp = Dense(8, activation='relu', name='d2', kernel_regularizer=l1_reg)(d1)
  40. output = Dense(1, name='d5')(nwp)
  41. output_f = Flatten()(output)
  42. model = Model(nwp_input, output_f)
  43. adam = optimizers.Adam(learning_rate=opt.Model['learning_rate'], beta_1=0.9, beta_2=0.999, epsilon=1e-7, amsgrad=True)
  44. reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.01, patience=5, verbose=1)
  45. model.compile(loss=rmse, optimizer=adam)
  46. return model
  47. def train_init(self, opt):
  48. try:
  49. if opt.Model['add_train']:
  50. # 进行加强训练,支持修模
  51. base_train_model = get_h5_model_from_mongo(vars(opt))
  52. base_train_model.summary()
  53. self.logger.info("已加载加强训练基础模型")
  54. else:
  55. base_train_model = self.get_keras_model(opt)
  56. return base_train_model
  57. except Exception as e:
  58. self.logger.info("加强训练加载模型权重失败:{}".format(e.args))
  59. def training(self, opt, train_and_valid_data):
  60. model = self.train_init(opt)
  61. # tf.reset_default_graph() # 清除默认图
  62. train_x, train_y, valid_x, valid_y = train_and_valid_data
  63. print("----------", np.array(train_x[0]).shape)
  64. print("++++++++++", np.array(train_x[1]).shape)
  65. model.summary()
  66. check_point = ModelCheckpoint(filepath='./var/' + 'fmi.h5', monitor='val_loss', save_best_only=True, mode='auto')
  67. early_stop = EarlyStopping(monitor='val_loss', patience=opt.Model['patience'], mode='auto')
  68. history = model.fit(train_x, train_y, batch_size=opt.Model['batch_size'], epochs=opt.Model['epoch'], verbose=2, validation_data=(valid_x, valid_y), callbacks=[check_point, early_stop], shuffle=False)
  69. loss = np.round(history.history['loss'], decimals=5)
  70. val_loss = np.round(history.history['val_loss'], decimals=5)
  71. self.logger.info("-----模型训练经过{}轮迭代-----".format(len(loss)))
  72. self.logger.info("训练集损失函数为:{}".format(loss))
  73. self.logger.info("验证集损失函数为:{}".format(val_loss))
  74. return model
  75. def predict(self, test_x, batch_size=1):
  76. result = self.model.predict(test_x, batch_size=batch_size)
  77. self.logger.info("执行预测方法")
  78. return result
if __name__ == "__main__":
    # Placeholder entry point: this module is meant to be imported and driven
    # by the training/prediction pipeline; nothing runs standalone.
    run_code = 0