#!/usr/bin/env python
# -*- coding: utf-8 -*-
# time: 2024/5/6 13:25
# file: time_series.py
# author: David
# company: shenyang JY
import os.path
from flask import Flask
from keras.layers import Input, Dense, LSTM, concatenate, Conv1D, Conv2D, MaxPooling1D, BatchNormalization, Flatten, Dropout, Reshape, Lambda, TimeDistributed
from keras.models import Model, load_model
from keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard
from keras import optimizers, regularizers
import keras.backend as K
import numpy as np
np.random.seed(42)
from cache.sloss import SouthLoss, NorthEastLoss
import tensorflow as tf
tf.compat.v1.set_random_seed(1234)
from threading import Lock

model_lock = Lock()


def rmse(y_true, y_pred):
    return K.sqrt(K.mean(K.square(y_pred - y_true)))


def mae(y_true, y_pred):
    return K.mean(K.abs(y_pred - y_true), axis=-1)


var_dir = os.path.dirname(os.path.dirname(__file__))


class FMI(object):
    model = None
    train = False

    def __init__(self, log, args, graph, sess):
        self.logger = log
        self.graph = graph
        self.sess = sess
        opt = args.parse_args_and_yaml()
        with self.graph.as_default():
            tf.compat.v1.keras.backend.set_session(self.sess)
            FMI.get_model(opt)

    @staticmethod
    def get_model(opt):
        """
        Singleton pattern plus a thread lock, to keep asynchronous model
        (re)loading thread-safe.
        """
        try:
            if FMI.model is None or FMI.train is True:
                with model_lock:
                    FMI.model = FMI.get_keras_model(opt)
                    FMI.model.load_weights(os.path.join(var_dir, 'var', 'fmi.h5'))
        except Exception as e:
            print("Failed to load model weights: {}".format(e.args))

    @staticmethod
    def get_keras_model(opt):
        db_loss = NorthEastLoss(opt)
        south_loss = SouthLoss(opt)
        l1_reg = regularizers.l1(opt.Model['lambda_value_1'])
        l2_reg = regularizers.l2(opt.Model['lambda_value_2'])
        nwp_input = Input(shape=(opt.Model['time_step'], opt.Model['input_size_nwp']), name='nwp')
        env_input = Input(shape=(opt.Model['his_points'], opt.Model['input_size_env']), name='env')

        # NWP branch: pointwise Conv1D followed by two Dense projections.
        con1 = Conv1D(filters=64, kernel_size=1, strides=1, padding='valid', activation='relu',
                      kernel_regularizer=l2_reg)(nwp_input)
        d1 = Dense(32, activation='relu', name='d1', kernel_regularizer=l1_reg)(con1)
        nwp = Dense(8, activation='relu', name='d2', kernel_regularizer=l1_reg)(d1)

        # Environment branch: Conv1D + max pooling, then a stack of LSTMs.
        con2 = Conv1D(filters=64, kernel_size=5, strides=1, padding='valid', activation='relu',
                      kernel_regularizer=l2_reg)(env_input)
        env = MaxPooling1D(pool_size=5, strides=1, padding='valid', data_format='channels_last')(con2)
        for i in range(opt.Model['lstm_layers']):
            # Only the last LSTM layer collapses the sequence dimension.
            rs = True
            if i == opt.Model['lstm_layers'] - 1:
                rs = False
            env = LSTM(units=opt.Model['hidden_size'], return_sequences=rs, name='env_lstm' + str(i),
                       kernel_regularizer=l2_reg)(env)
        tiao = Dense(16, name='d4', kernel_regularizer=l1_reg)(env)

        # Either fuse both branches or use the flattened NWP branch alone.
        if opt.Model['fusion']:
            nwpf = Flatten()(nwp)
            fusion = concatenate([nwpf, tiao])
        else:
            fusion = Flatten()(nwp)
        output = Dense(opt.Model['output_size'], name='d5')(fusion)

        model = Model([env_input, nwp_input], output)
        adam = optimizers.Adam(learning_rate=opt.Model['learning_rate'], beta_1=0.9, beta_2=0.999, epsilon=1e-7,
                               amsgrad=True)
        model.compile(loss=rmse, optimizer=adam)
        return model
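
    # NOTE (added for readability; the key list is taken from the accesses above,
    # the groupings are only a reading aid): get_keras_model() expects opt.Model
    # to provide at least
    #   time_step, input_size_nwp      -> shape of the 'nwp' input
    #   his_points, input_size_env     -> shape of the 'env' input
    #   lambda_value_1, lambda_value_2 -> L1 / L2 regularization strengths
    #   lstm_layers, hidden_size       -> depth and width of the env LSTM stack
    #   fusion, output_size            -> branch fusion switch and head size
    #   learning_rate                  -> Adam learning rate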

    def train_init(self, opt):
        tf.compat.v1.keras.backend.set_session(self.sess)
        model = FMI.get_keras_model(opt)
        try:
            if opt.Model['add_train'] and opt.authentication['repair'] != "null":
                # Incremental training: start from the saved weights (supports model repair).
                model.load_weights(os.path.join(var_dir, 'var', 'fmi.h5'))
                self.logger.info("Loaded the base model for incremental training")
        except Exception as e:
            self.logger.info("Incremental training: failed to load model weights: {}".format(e.args))
        model.summary()
        return model

    def training(self, opt, train_and_valid_data):
        model = self.train_init(opt)
        train_X, train_Y, valid_X, valid_Y = train_and_valid_data
        print("----------", np.array(train_X[0]).shape)
        print("++++++++++", np.array(train_X[1]).shape)
        # weight_lstm_1, bias_lstm_1 = model.get_layer('d1').get_weights()
        # print("weight_lstm_1 = ", weight_lstm_1)
        # print("bias_lstm_1 = ", bias_lstm_1)
        check_point = ModelCheckpoint(filepath=os.path.join(var_dir, 'var', 'fmi.h5'), monitor='val_loss',
                                      save_best_only=True, mode='auto')
        early_stop = EarlyStopping(monitor='val_loss', patience=opt.Model['patience'], mode='auto')
        # tbCallBack = TensorBoard(log_dir='../figure',
        #                          histogram_freq=0,
        #                          write_graph=True,
        #                          write_images=True)
        history = model.fit(train_X, train_Y, batch_size=opt.Model['batch_size'], epochs=opt.Model['epoch'], verbose=2,
                            validation_data=(valid_X, valid_Y), callbacks=[check_point, early_stop])
        loss = np.round(history.history['loss'], decimals=5)
        val_loss = np.round(history.history['val_loss'], decimals=5)
        self.logger.info("-----Training ran for {} epochs-----".format(len(loss)))
        self.logger.info("Training loss history: {}".format(loss))
        self.logger.info("Validation loss history: {}".format(val_loss))
        self.logger.info("Training finished, old model object id: {}".format(id(FMI.model)))
        with self.graph.as_default():
            tf.compat.v1.keras.backend.set_session(self.sess)
            FMI.train = True
            FMI.get_model(opt)
            FMI.train = False
        self.logger.info("Reloaded the model under the thread lock, new object id: {}".format(id(FMI.model)))

    def predict(self, test_X, batch_size=1):
        with self.graph.as_default():
            with self.sess.as_default():
                result = FMI.model.predict(test_X, batch_size=batch_size)
        self.logger.info("Prediction method executed")
        return result
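

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original entry point).
# It shows how FMI is apparently meant to be wired up with a shared TF1-style
# graph/session and a config object whose parse_args_and_yaml() exposes the
# opt.Model / opt.authentication fields read above. The _DemoOpt / _DemoArgs
# helpers, the logger setup and every numeric value below are assumptions for
# illustration; whether SouthLoss / NorthEastLoss accept such a minimal config
# depends on cache.sloss and is not guaranteed here.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import logging

    class _DemoOpt(object):
        # Placeholder hyper-parameters, not the project's real configuration.
        Model = {
            'time_step': 16, 'input_size_nwp': 24,
            'his_points': 16, 'input_size_env': 5,
            'lambda_value_1': 0.0, 'lambda_value_2': 0.0,
            'lstm_layers': 1, 'hidden_size': 64,
            'fusion': True, 'output_size': 16,
            'learning_rate': 0.001, 'batch_size': 32,
            'epoch': 1, 'patience': 5, 'add_train': False,
        }
        authentication = {'repair': "null"}

    class _DemoArgs(object):
        def parse_args_and_yaml(self):
            return _DemoOpt()

    logging.basicConfig(level=logging.INFO)
    graph = tf.Graph()
    session = tf.compat.v1.Session(graph=graph)
    fmi = FMI(logging.getLogger('fmi'), _DemoArgs(), graph, session)

    # Exercise predict() with random inputs ordered as [env, nwp],
    # matching Model([env_input, nwp_input], output) above.
    env_x = np.random.rand(2, _DemoOpt.Model['his_points'], _DemoOpt.Model['input_size_env'])
    nwp_x = np.random.rand(2, _DemoOpt.Model['time_step'], _DemoOpt.Model['input_size_nwp'])
    print(fmi.predict([env_x, nwp_x]).shape)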