#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName :tf_bp.py
# @Time :2025/2/13 13:34
# @Author :David
# @Company: shenyang JY
import tensorflow as tf  # needed for tf.keras.losses.MeanSquaredError below
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Dense, LSTM, concatenate, Conv1D, Conv2D, MaxPooling1D, Reshape, Flatten, Dropout, BatchNormalization, LeakyReLU
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, ReduceLROnPlateau
from tensorflow.keras import optimizers, regularizers
from models_processing.model_tf.losses import region_loss
from models_processing.model_tf.settings import set_deterministic
import numpy as np
from common.database_dml_koi import *
from threading import Lock
import argparse

# Module-level lock guarding model loading; fixed seed for reproducibility.
model_lock = Lock()
set_deterministic(42)


class BPHandler(object):
    def __init__(self, logger, args):
        self.logger = logger
        self.opt = argparse.Namespace(**args)
        self.model = None
        self.model_params = None

    def get_model(self, args):
        """
        Singleton-style loading guarded by a thread lock, to avoid
        thread-safety issues when the model is loaded asynchronously.
        """
        try:
            with model_lock:
                # loss = region_loss(self.opt)
                self.model, self.model_params = get_keras_model_from_mongo(args)
        except Exception as e:
            self.logger.info("Failed to load model weights: {}".format(e.args))

    # @staticmethod
    # def get_keras_model(opt):
    #     loss = region_loss(opt)
    #     model = Sequential([
    #         Dense(64, input_dim=opt.Model['input_size'], activation='relu'),  # input + first hidden layer
    #         Dropout(0.2),
    #         Dense(32, activation='relu'),  # hidden layer
    #         Dropout(0.3),
    #         Dense(16, activation='relu'),  # hidden layer
    #         Dense(1, activation='linear')  # output layer, single neuron (regression)
    #     ])
    #     adam = optimizers.Adam(learning_rate=opt.Model['learning_rate'], beta_1=0.9, beta_2=0.999, epsilon=1e-7, amsgrad=True)
    #     model.compile(loss=loss, optimizer=adam)
    #     return model

    @staticmethod
    def get_keras_model(opt):
        # Custom loss function (verify correctness before enabling)
        # loss = region_loss(opt)
        # Network architecture
        model = Sequential([
            Dense(128, input_dim=opt.Model['input_size'], kernel_regularizer=l2(0.001)),
            LeakyReLU(alpha=0.1),
            BatchNormalization(),
            Dropout(0.3),
            Dense(64, kernel_regularizer=l2(0.001)),
            LeakyReLU(alpha=0.1),
            BatchNormalization(),
            Dropout(0.4),
            Dense(32, kernel_regularizer=l2(0.001)),
            LeakyReLU(alpha=0.1),
            Dense(1, activation='linear')
        ])
        # Optimizer configuration
        adam = optimizers.Adam(
            learning_rate=0.001,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6
        )
        model.compile(loss=tf.keras.losses.MeanSquaredError(), optimizer=adam)
        return model

    def train_init(self):
        try:
            # Load the previously saved model as the base for reinforced training / model repair.
            loss = region_loss(self.opt)
            base_train_model, self.model_params = get_keras_model_from_mongo(vars(self.opt), {type(loss).__name__: loss})
            base_train_model.summary()
            self.logger.info("Loaded base model for reinforced training")
            return base_train_model
        except Exception as e:
            self.logger.info("Failed to load weights for reinforced training: {}".format(e.args))
            return False

    def training(self, model, train_and_valid_data):
        train_x, train_y, valid_x, valid_y = train_and_valid_data
        # Debug output: shapes of the first two training samples
        print("----------", np.array(train_x[0]).shape)
        print("++++++++++", np.array(train_x[1]).shape)
        model.summary()
        early_stop = EarlyStopping(monitor='val_loss', patience=self.opt.Model['patience'], mode='auto')
        history = model.fit(train_x, train_y, batch_size=self.opt.Model['batch_size'], epochs=self.opt.Model['epoch'],
                            verbose=2, validation_data=(valid_x, valid_y), callbacks=[early_stop], shuffle=False)
        loss = np.round(history.history['loss'], decimals=5)
        val_loss = np.round(history.history['val_loss'], decimals=5)
        self.logger.info("-----Training ran for {} epochs-----".format(len(loss)))
        self.logger.info("Training loss: {}".format(loss))
        self.logger.info("Validation loss: {}".format(val_loss))
        return model

    def predict(self, test_x, batch_size=1):
        result = self.model.predict(test_x, batch_size=batch_size)
        self.logger.info("Prediction method executed")
        return result


if __name__ == "__main__":
    run_code = 0
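
    # --- Usage sketch (illustrative only, not the production entry point) ---
    # A minimal smoke test of get_keras_model()/training()/predict() on random data.
    # The config keys below ('input_size', 'learning_rate', 'patience', 'batch_size',
    # 'epoch') mirror how self.opt.Model is read above; the demo values, the logger
    # name, and the random data are assumptions, not part of the real pipeline.
    import logging

    logging.basicConfig(level=logging.INFO)
    demo_args = {
        "Model": {
            "input_size": 10,      # assumed feature dimension for the demo
            "learning_rate": 0.001,
            "patience": 3,
            "batch_size": 32,
            "epoch": 2,
        }
    }
    handler = BPHandler(logging.getLogger("tf_bp_demo"), demo_args)
    demo_model = BPHandler.get_keras_model(handler.opt)

    # Random regression data with the assumed feature dimension.
    rng = np.random.default_rng(42)
    train_x = rng.normal(size=(256, 10)).astype("float32")
    train_y = rng.normal(size=(256, 1)).astype("float32")
    valid_x = rng.normal(size=(64, 10)).astype("float32")
    valid_y = rng.normal(size=(64, 1)).astype("float32")

    demo_model = handler.training(demo_model, [train_x, train_y, valid_x, valid_y])
    handler.model = demo_model  # predict() reads self.model, normally populated by get_model()
    print(handler.predict(valid_x, batch_size=32)[:5])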