tf_transformer.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. #!/usr/bin/env python
  2. # -*- coding:utf-8 -*-
  3. # @FileName :tf_transformer.py
  4. # @Time :2025/5/08 14:03
  5. # @Author :David
  6. # @Company: shenyang JY
  7. from tensorflow.keras.layers import Input, Dense, LSTM, concatenate, Conv1D, Conv2D, MaxPooling1D, Reshape, Flatten, LayerNormalization, Dropout
  8. from tensorflow.keras.models import Model, load_model
  9. from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, ReduceLROnPlateau
  10. from tensorflow.keras import optimizers, regularizers
  11. from models_processing.model_tf.losses import region_loss
  12. import numpy as np
  13. from common.database_dml import *
  14. from models_processing.model_tf.settings import set_deterministic
  15. from threading import Lock
  16. import argparse
  17. model_lock = Lock()
  18. set_deterministic(42)
  19. class TransformerHandler(object):
  20. def __init__(self, logger, args):
  21. self.logger = logger
  22. self.opt = argparse.Namespace(**args)
  23. self.model = None
  24. self.model_params = None
  25. def get_model(self, args):
  26. """
  27. 单例模式+线程锁,防止在异步加载时引发线程安全
  28. """
  29. try:
  30. with model_lock:
  31. loss = region_loss(self.opt)
  32. self.model, self.model_params = get_keras_model_from_mongo(args, {type(loss).__name__: loss})
  33. except Exception as e:
  34. self.logger.info("加载模型权重失败:{}".format(e.args))
  35. @staticmethod
  36. def get_transformer_model(opt, time_series=1):
  37. hidden_size = opt.Model.get('hidden_size', 64)
  38. num_heads = opt.Model.get('num_heads', 4)
  39. ff_dim = opt.Model.get('ff_dim', 128)
  40. l2_reg = regularizers.l2(opt.Model.get('lambda_value_2', 0.0))
  41. nwp_input = Input(shape=(opt.Model['time_step'] * time_series, opt.Model['input_size']), name='nwp')
  42. # 输入嵌入
  43. x = Conv1D(hidden_size, 1, kernel_regularizer=l2_reg)(nwp_input)
  44. # Transformer编码器层
  45. for _ in range(opt.Model.get('num_layers', 2)):
  46. # 多头自注意力
  47. x = tf.keras.layers.MultiHeadAttention(
  48. num_heads=num_heads, key_dim=hidden_size,
  49. kernel_regularizer=l2_reg
  50. )(x, x)
  51. x = LayerNormalization()(x)
  52. x = tf.keras.layers.Dropout(0.1)(x)
  53. # 前馈网络
  54. x = tf.keras.layers.Dense(ff_dim, activation='relu', kernel_regularizer=l2_reg)(x)
  55. x = tf.keras.layers.Dense(hidden_size, kernel_regularizer=l2_reg)(x)
  56. x = LayerNormalization()(x)
  57. x = tf.keras.layers.Dropout(0.1)(x)
  58. # 提取中间时间步
  59. # start_idx = (time_steps - output_steps) // 2
  60. # x = x[:, start_idx:start_idx + output_steps, :]
  61. # 输出层
  62. output = Dense(1, name='cdq_output')(x) # 或者使用所有时间步
  63. output = Flatten(name='flatten')(output)
  64. model = Model(nwp_input, output)
  65. # 编译模型
  66. adam = optimizers.Adam(
  67. learning_rate=opt.Model.get('learning_rate', 0.001),
  68. beta_1=0.9, beta_2=0.999, epsilon=1e-7, amsgrad=True
  69. )
  70. loss = region_loss(opt)
  71. model.compile(loss=loss, optimizer=adam)
  72. return model
  73. def train_init(self):
  74. try:
  75. # 进行加强训练,支持修模
  76. loss = region_loss(self.opt)
  77. base_train_model, self.model_params = get_keras_model_from_mongo(vars(self.opt), {type(loss).__name__: loss})
  78. base_train_model.summary()
  79. self.logger.info("已加载加强训练基础模型")
  80. return base_train_model
  81. except Exception as e:
  82. self.logger.info("加载加强训练模型权重失败:{}".format(e.args))
  83. return False
  84. def training(self, model, train_and_valid_data):
  85. model.summary()
  86. train_x, train_y, valid_x, valid_y = train_and_valid_data
  87. early_stop = EarlyStopping(monitor='val_loss', patience=self.opt.Model['patience'], mode='auto')
  88. history = model.fit(train_x, train_y, batch_size=self.opt.Model['batch_size'], epochs=self.opt.Model['epoch'],
  89. verbose=2, validation_data=(valid_x, valid_y), callbacks=[early_stop], shuffle=False)
  90. loss = np.round(history.history['loss'], decimals=5)
  91. val_loss = np.round(history.history['val_loss'], decimals=5)
  92. self.logger.info("-----模型训练经过{}轮迭代-----".format(len(loss)))
  93. self.logger.info("训练集损失函数为:{}".format(loss))
  94. self.logger.info("验证集损失函数为:{}".format(val_loss))
  95. return model
  96. def predict(self, test_x, batch_size=1):
  97. result = self.model.predict(test_x, batch_size=batch_size)
  98. self.logger.info("执行预测方法")
  99. return result
  100. if __name__ == "__main__":
  101. run_code = 0