#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName :tf_bp.py
# @Time :2025/2/13 13:34
# @Author :David
# @Company: shenyang JY
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Dense, LSTM, concatenate, Conv1D, Conv2D, MaxPooling1D, Reshape, Flatten
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, ReduceLROnPlateau
from tensorflow.keras import optimizers, regularizers
from models_processing.losses.loss_cdq import rmse
import numpy as np
from common.database_dml import *
from threading import Lock

model_lock = Lock()


class BPHandler(object):
    def __init__(self, logger):
        self.logger = logger
        self.model = None

    def get_model(self, args):
        """
        Load the model as a singleton guarded by a thread lock,
        so that asynchronous loads stay thread-safe.
        """
        try:
            with model_lock:
                # NPHandler.model = NPHandler.get_keras_model(opt)
                self.model = get_h5_model_from_mongo(args, {'rmse': rmse})
        except Exception as e:
            self.logger.info("Failed to load model weights: {}".format(e.args))

    @staticmethod
    def get_keras_model(opt):
        model = Sequential([
            Dense(64, input_dim=opt.Model['input_size'], activation='relu'),  # input + first hidden layer, 64 units
            Dense(32, activation='relu'),  # hidden layer, 32 units
            Dense(1, activation='linear')  # output layer, 1 unit (regression)
        ])
        adam = optimizers.Adam(learning_rate=opt.Model['learning_rate'], beta_1=0.9, beta_2=0.999, epsilon=1e-7, amsgrad=True)
        model.compile(loss=rmse, optimizer=adam)
        return model

    def train_init(self, opt):
        try:
            if opt.Model['add_train']:
                # Incremental training: continue from the base model stored in MongoDB (supports model revision)
                base_train_model = get_h5_model_from_mongo(vars(opt), {'rmse': rmse})
                base_train_model.summary()
                self.logger.info("Loaded the base model for incremental training")
            else:
                base_train_model = self.get_keras_model(opt)
            return base_train_model
        except Exception as e:
            self.logger.info("Failed to load model weights for incremental training: {}".format(e.args))

    def training(self, opt, train_and_valid_data):
        model = self.train_init(opt)
        # tf.reset_default_graph() # clear the default graph
        train_x, train_y, valid_x, valid_y = train_and_valid_data
        self.logger.info("Training input shape: {}".format(np.array(train_x).shape))
        self.logger.info("Training target shape: {}".format(np.array(train_y).shape))
        model.summary()
        early_stop = EarlyStopping(monitor='val_loss', patience=opt.Model['patience'], mode='auto')
        history = model.fit(train_x, train_y, batch_size=opt.Model['batch_size'], epochs=opt.Model['epoch'], verbose=2, validation_data=(valid_x, valid_y), callbacks=[early_stop], shuffle=False)
        loss = np.round(history.history['loss'], decimals=5)
        val_loss = np.round(history.history['val_loss'], decimals=5)
        self.logger.info("----- model training ran for {} epochs -----".format(len(loss)))
        self.logger.info("Training loss: {}".format(loss))
        self.logger.info("Validation loss: {}".format(val_loss))
        return model

    def predict(self, test_x, batch_size=1):
        result = self.model.predict(test_x, batch_size=batch_size)
        self.logger.info("Ran the prediction step")
        return result


if __name__ == "__main__":
    run_code = 0
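    # Illustrative usage sketch (not part of the original module). It assumes a
    # hypothetical `opt`-style config object exposing opt.Model, plus pre-built
    # NumPy arrays train_x/train_y/valid_x/valid_y shaped for the Dense input layer:
    #
    #   import logging
    #   from types import SimpleNamespace
    #
    #   logger = logging.getLogger("tf_bp")
    #   opt = SimpleNamespace(Model={
    #       'input_size': 10, 'learning_rate': 1e-3, 'batch_size': 32,
    #       'epoch': 100, 'patience': 10, 'add_train': False,
    #   })
    #   handler = BPHandler(logger)
    #   handler.model = handler.training(opt, (train_x, train_y, valid_x, valid_y))
    #   predictions = handler.predict(valid_x, batch_size=opt.Model['batch_size'])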