# tf_test.py — (copy/paste listing header and line-number residue removed)
  1. #!/usr/bin/env python
  2. # -*- coding:utf-8 -*-
  3. # @FileName :tf_lstm.py
  4. # @Time :2025/2/12 14:03
  5. # @Author :David
  6. # @Company: shenyang JY
  7. from tensorflow.keras.layers import Input, Dense, LSTM, concatenate, Conv1D, Conv2D, MaxPooling1D, Reshape, Flatten
  8. from tensorflow.keras.models import Model, load_model
  9. from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, ReduceLROnPlateau
  10. from tensorflow.keras import optimizers, regularizers
  11. from tensorflow.keras.layers import BatchNormalization, GlobalAveragePooling1D, Dropout, Add, Concatenate, Multiply
  12. from models_processing.model_tf.losses import region_loss
  13. import numpy as np
  14. from common.database_dml_koi import *
  15. from models_processing.model_tf.settings import set_deterministic
  16. from threading import Lock
  17. import argparse
  18. model_lock = Lock()
  19. set_deterministic(42)
  20. class TSHandler(object):
  21. def __init__(self, logger, args):
  22. self.logger = logger
  23. self.opt = argparse.Namespace(**args)
  24. self.model = None
  25. def get_model(self, args):
  26. """
  27. 单例模式+线程锁,防止在异步加载时引发线程安全
  28. """
  29. try:
  30. with model_lock:
  31. loss = region_loss(self.opt)
  32. self.model = get_keras_model_from_mongo(args, {type(loss).__name__: loss})
  33. except Exception as e:
  34. self.logger.info("加载模型权重失败:{}".format(e.args))
  35. @staticmethod
  36. def get_keras_model(opt):
  37. """优化后的新能源功率预测模型
  38. 主要改进点:
  39. 1. 多尺度特征提取
  40. 2. 注意力机制
  41. 3. 残差连接
  42. 4. 弹性正则化
  43. 5. 自适应学习率调整
  44. """
  45. # 正则化配置
  46. l1_l2_reg = regularizers.l1_l2(
  47. l1=opt.Model['lambda_value_1'],
  48. l2=opt.Model['lambda_value_2']
  49. )
  50. # 输入层
  51. nwp_input = Input(shape=(opt.Model['time_step'], opt.Model['input_size']), name='nwp_input')
  52. # %% 多尺度特征提取模块
  53. def multi_scale_block(input_layer):
  54. # 并行卷积路径
  55. conv3 = Conv1D(64, 3, padding='causal', activation='relu')(input_layer)
  56. conv5 = Conv1D(64, 5, padding='causal', activation='relu')(input_layer)
  57. return Concatenate()([conv3, conv5])
  58. # 特征主干
  59. x = multi_scale_block(nwp_input)
  60. # %% 残差注意力模块
  61. def residual_attention_block(input_layer, filters):
  62. # 主路径
  63. y = Conv1D(filters, 3, padding='same', activation='relu')(input_layer)
  64. y = BatchNormalization()(y)
  65. # 注意力门控
  66. attention = Dense(filters, activation='sigmoid')(y)
  67. y = Multiply()([y, attention])
  68. # 残差连接
  69. shortcut = Conv1D(filters, 1, padding='same')(input_layer)
  70. return Add()([y, shortcut])
  71. x = residual_attention_block(x, 128)
  72. x = Dropout(0.3)(x)
  73. # %% 特征聚合
  74. x = GlobalAveragePooling1D()(x) # 替代Flatten保留时序特征
  75. # %% 深度可调全连接层
  76. x = Dense(256, activation='swish', kernel_regularizer=l1_l2_reg)(x)
  77. x = BatchNormalization()(x)
  78. x = Dropout(0.5)(x)
  79. # %% 输出层(可扩展为概率预测)
  80. output = Dense(1, activation='linear', name='main_output')(x)
  81. # 概率预测扩展(可选)
  82. # variance = Dense(1, activation='softplus')(x) # 输出方差
  83. # output = Concatenate()([output, variance])
  84. # %% 模型编译
  85. model = Model(inputs=nwp_input, outputs=output)
  86. # 自适应优化器配置
  87. adam = optimizers.Adam(
  88. learning_rate=opt.Model['learning_rate'],
  89. beta_1=0.92, # 调整动量参数
  90. beta_2=0.999,
  91. epsilon=1e-07,
  92. amsgrad=True
  93. )
  94. # 编译配置(假设region_loss已定义)
  95. model.compile(
  96. loss=region_loss(opt), # 自定义损失函数
  97. optimizer=adam,
  98. metrics=['mae', 'mse'] # 监控指标
  99. )
  100. return model
  101. def train_init(self):
  102. try:
  103. if self.opt.Model['add_train']:
  104. # 进行加强训练,支持修模
  105. loss = region_loss(self.opt)
  106. base_train_model = get_keras_model_from_mongo(vars(self.opt), {type(loss).__name__: loss})
  107. base_train_model.summary()
  108. self.logger.info("已加载加强训练基础模型")
  109. else:
  110. base_train_model = self.get_keras_model(self.opt)
  111. return base_train_model
  112. except Exception as e:
  113. self.logger.info("加载模型权重失败:{}".format(e.args))
  114. def training(self, train_and_valid_data):
  115. model = self.train_init()
  116. model.summary()
  117. train_x, train_y, valid_x, valid_y = train_and_valid_data
  118. # 回调函数配置
  119. callbacks = [
  120. EarlyStopping(monitor='val_loss', patience=self.opt.Model['patience'], restore_best_weights=True),
  121. ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=8, min_lr=1e-7)
  122. ]
  123. history = model.fit(train_x, train_y, batch_size=self.opt.Model['batch_size'], epochs=self.opt.Model['epoch'],
  124. verbose=2, validation_data=(valid_x, valid_y), callbacks=callbacks, shuffle=False)
  125. loss = np.round(history.history['loss'], decimals=5)
  126. val_loss = np.round(history.history['val_loss'], decimals=5)
  127. self.logger.info("-----模型训练经过{}轮迭代-----".format(len(loss)))
  128. self.logger.info("训练集损失函数为:{}".format(loss))
  129. self.logger.info("验证集损失函数为:{}".format(val_loss))
  130. return model
  131. def predict(self, test_x, batch_size=1):
  132. result = self.model.predict(test_x, batch_size=batch_size)
  133. self.logger.info("执行预测方法")
  134. return result
  135. if __name__ == "__main__":
  136. run_code = 0