# tf_transformer.py
  1. #!/usr/bin/env python
  2. # -*- coding:utf-8 -*-
  3. # @FileName :tf_transformer.py
  4. # @Time :2025/5/08 14:03
  5. # @Author :David
  6. # @Company: shenyang JY
  7. from tensorflow.keras.initializers import glorot_uniform, orthogonal
  8. from tensorflow.keras.layers import Input, Dense, LSTM, concatenate, Conv1D, Conv2D, MaxPooling1D, Reshape, Flatten, LayerNormalization, Dropout, Layer, Add, MultiHeadAttention, Dropout
  9. from tensorflow.keras.models import Model, load_model
  10. from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, ReduceLROnPlateau
  11. from tensorflow.keras import optimizers, regularizers
  12. from models_processing.model_tf.losses import region_loss
  13. import numpy as np
  14. from common.database_dml import *
  15. from models_processing.model_tf.settings import set_deterministic
  16. from threading import Lock
  17. import argparse
  18. model_lock = Lock()
  19. set_deterministic(42)
  20. class PositionalEncoding(tf.keras.layers.Layer):
  21. """自定义位置编码层(支持序列化)"""
  22. def __init__(self, max_len, d_model, **kwargs):
  23. super().__init__(**kwargs)
  24. self.max_len = max_len # 将参数保存为实例属性
  25. self.d_model = d_model
  26. # 位置编码在初始化时生成
  27. self.position_embedding = self.positional_encoding(max_len, d_model)
  28. def get_angles(self, pos, i, d_model):
  29. # 计算角度参数
  30. angles = 1 / tf.pow(10000., (2 * (i // 2)) / tf.cast(d_model, tf.float32))
  31. return pos * angles
  32. def positional_encoding(self, max_len, d_model):
  33. # 生成位置编码矩阵
  34. angle_rads = self.get_angles(
  35. pos=tf.range(max_len, dtype=tf.float32)[:, tf.newaxis],
  36. i=tf.range(d_model, dtype=tf.float32)[tf.newaxis, :],
  37. d_model=d_model
  38. )
  39. # 拼接正弦和余弦编码
  40. sines = tf.math.sin(angle_rads[:, 0::2])
  41. cosines = tf.math.cos(angle_rads[:, 1::2])
  42. pos_encoding = tf.concat([sines, cosines], axis=-1)
  43. return pos_encoding[tf.newaxis, ...] # 增加批次维度
  44. def call(self, inputs):
  45. # 动态截取与输入序列长度匹配的部分
  46. seq_len = tf.shape(inputs)[1]
  47. return inputs + self.position_embedding[:, :seq_len, :]
  48. def get_config(self):
  49. # 将参数序列化(关键步骤!)
  50. config = super().get_config()
  51. config.update({
  52. 'max_len': self.max_len,
  53. 'd_model': self.d_model,
  54. })
  55. return config
  56. class TransformerHandler(object):
  57. def __init__(self, logger, args):
  58. self.logger = logger
  59. self.opt = argparse.Namespace(**args)
  60. self.model = None
  61. self.model_params = None
  62. def get_model(self, args):
  63. """
  64. 单例模式+线程锁,防止在异步加载时引发线程安全
  65. """
  66. try:
  67. with model_lock:
  68. loss = region_loss(self.opt)
  69. self.model, self.model_params = get_keras_model_from_mongo(args, {type(loss).__name__: loss, 'PositionalEncoding': PositionalEncoding})
  70. except Exception as e:
  71. self.logger.info("加载模型权重失败:{}".format(e.args))
  72. @staticmethod
  73. def get_transformer_model(opt, time_series=1):
  74. hidden_size = opt.Model.get('hidden_size', 64)
  75. num_heads = opt.Model.get('num_heads', 4)
  76. ff_dim = opt.Model.get('ff_dim', 128)
  77. l2_reg = regularizers.l2(opt.Model.get('lambda_value_2', 0.01))
  78. nwp_input = Input(shape=(opt.Model['time_step'] * time_series, opt.Model['input_size']))
  79. # 嵌入层 + 位置编码
  80. x = Conv1D(hidden_size, kernel_size=3, padding='same', kernel_regularizer=l2_reg)(nwp_input)
  81. x = PositionalEncoding(opt.Model['time_step'], hidden_size)(x)
  82. # Transformer编码层(带残差连接)
  83. for _ in range(opt.Model.get('num_layers', 2)):
  84. # 自注意力
  85. residual = x
  86. x = MultiHeadAttention(num_heads=num_heads, key_dim=hidden_size)(x, x)
  87. x = Dropout(0.1)(x)
  88. x = Add()([residual, x])
  89. x = LayerNormalization()(x)
  90. # 前馈网络
  91. residual = x
  92. x = Dense(ff_dim, activation='relu')(x)
  93. x = Dense(hidden_size)(x)
  94. x = Dropout(0.1)(x)
  95. x = Add()([residual, x])
  96. x = LayerNormalization()(x)
  97. # 输出层(预测每个时间步)
  98. output = Dense(1, activation='linear')(x)
  99. # output = tf.keras.layers.Lambda(lambda x: tf.squeeze(x, axis=-1))(output)
  100. output = Flatten(name='Flatten')(output)
  101. model = Model(nwp_input, output)
  102. model.compile(loss='mse', optimizer=optimizers.Adam(learning_rate=1e-4))
  103. return model
  104. def train_init(self):
  105. try:
  106. # 进行加强训练,支持修模
  107. loss = region_loss(self.opt)
  108. base_train_model, self.model_params = get_keras_model_from_mongo(vars(self.opt), {type(loss).__name__: loss, 'PositionalEncoding': PositionalEncoding})
  109. base_train_model.summary()
  110. self.logger.info("已加载加强训练基础模型")
  111. return base_train_model
  112. except Exception as e:
  113. self.logger.info("加载加强训练模型权重失败:{}".format(e.args))
  114. return False
  115. def training(self, model, train_and_valid_data):
  116. model.summary()
  117. train_x, train_y, valid_x, valid_y = train_and_valid_data
  118. early_stop = EarlyStopping(monitor='val_loss', patience=self.opt.Model['patience'], mode='auto')
  119. history = model.fit(train_x, train_y, batch_size=self.opt.Model['batch_size'], epochs=self.opt.Model['epoch'],
  120. verbose=2, validation_data=(valid_x, valid_y), callbacks=[early_stop], shuffle=False)
  121. loss = np.round(history.history['loss'], decimals=5)
  122. val_loss = np.round(history.history['val_loss'], decimals=5)
  123. self.logger.info("-----模型训练经过{}轮迭代-----".format(len(loss)))
  124. self.logger.info("训练集损失函数为:{}".format(loss))
  125. self.logger.info("验证集损失函数为:{}".format(val_loss))
  126. return model
  127. def predict(self, test_x, batch_size=1):
  128. result = self.model.predict(test_x, batch_size=batch_size)
  129. self.logger.info("执行预测方法")
  130. return result
# Placeholder entry point: this module is meant to be imported, not run directly.
if __name__ == "__main__":
    run_code = 0