David, 1 week ago
Parent
Commit
4e93853131

+ 4 - 7
models_processing/model_tf/tf_transformer.py

@@ -38,9 +38,6 @@ class TransformerHandler(object):
 
     @staticmethod
     def get_transformer_model(opt, time_series=1):
-        time_steps = 48
-        input_features = 21
-        output_steps = 16
         hidden_size = opt.Model.get('hidden_size', 64)
         num_heads = opt.Model.get('num_heads', 4)
         ff_dim = opt.Model.get('ff_dim', 128)
@@ -68,12 +65,12 @@ class TransformerHandler(object):
             x = tf.keras.layers.Dropout(0.1)(x)
 
         # extract the middle time steps
-        start_idx = (time_steps - output_steps) // 2
-        x = x[:, start_idx:start_idx + output_steps, :]
+        # start_idx = (time_steps - output_steps) // 2
+        # x = x[:, start_idx:start_idx + output_steps, :]
 
         # output layer
-        output = Dense(output_steps, name='cdq_output')(x[:, -1, :])  # or use all time steps
-
+        output = Dense(1, name='cdq_output')(x)  # or use all time steps
+        output = Flatten(name='flatten')(output)
         model = Model(nwp_input, output)
 
         # compile the model
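
Net effect of this hunk: the head no longer slices out the middle output_steps time steps; a shared Dense(1) now emits one value per time step and Flatten turns that into a flat vector whose length equals the input window. A minimal shape sketch of the new head only (time_steps, input_features and hidden are assumed values, and the plain Dense stands in for the transformer blocks):

    # minimal sketch of the new output head; values are illustrative, not taken from the model code
    from tensorflow.keras.layers import Dense, Flatten, Input
    from tensorflow.keras.models import Model

    time_steps, input_features, hidden = 16, 24, 64        # assumed window/feature/hidden sizes
    nwp_input = Input(shape=(time_steps, input_features))
    x = Dense(hidden)(nwp_input)                            # stand-in for the transformer encoder
    output = Dense(1, name='cdq_output')(x)                 # (batch, time_steps, 1)
    output = Flatten(name='flatten')(output)                # (batch, time_steps)
    model = Model(nwp_input, output)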

+ 4 - 4
models_processing/model_tf/tf_transformer_pre.py

@@ -26,7 +26,7 @@ np.random.seed(42)  # NumPy random seed
 app = Flask('tf_lstm_pre——service')
 
 current_dir = os.path.dirname(os.path.abspath(__file__))
-with open(os.path.join(current_dir, 'lstm.yaml'), 'r', encoding='utf-8') as f:
+with open(os.path.join(current_dir, 'transformer.yaml'), 'r', encoding='utf-8') as f:
     global_config = yaml.safe_load(f)  # read-only global config
 
 @app.before_request
@@ -45,7 +45,7 @@ def update_config():
     g.dh = DataHandler(logger, current_config)  # independent instance per request
     g.trans = TransformerHandler(logger, current_config)
 
-@app.route('/tf_lstm_predict', methods=['POST'])
+@app.route('/tf_transformer_predict', methods=['POST'])
 def model_prediction_lstm():
     # record the program start time
     start_time = time.time()
@@ -64,7 +64,7 @@ def model_prediction_lstm():
         trans.opt.cap = round(target_scaler.transform(np.array([[float(args['cap'])]]))[0, 0], 2)
         trans.get_model(args)
         dh.opt.features = json.loads(trans.model_params)['Model']['features'].split(',')
-        scaled_pre_x, pre_data = dh.pre_data_handler(pre_data, feature_scaler, time_series=args['time_series'], lstm_type=1)
+        scaled_pre_x, pre_data = dh.pre_data_handler(pre_data, feature_scaler, time_series=args['time_series'])
         res = list(chain.from_iterable(target_scaler.inverse_transform(trans.predict(scaled_pre_x))))
         pre_data['farm_id'] = args.get('farm_id', 'null')
         if int(args.get('algorithm_test', 0)):
@@ -105,7 +105,7 @@ def model_prediction_lstm():
 if __name__ == "__main__":
     print("Program starts execution!")
     from waitress import serve
-    serve(app, host="0.0.0.0", port=10114,
+    serve(app, host="0.0.0.0", port=10132,
           threads=8,  # number of worker threads (default 4; adjust to the hardware)
           channel_timeout=600  # connection timeout (seconds)
           )
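
With the route and port changes, clients of the prediction service must switch from /tf_lstm_predict on 10114 to /tf_transformer_predict on 10132. A hypothetical call; the payload fields are inferred from the handler above ('cap', 'time_series', 'farm_id') and are not a complete contract:

    # hypothetical client request against the renamed prediction endpoint
    import requests

    resp = requests.post(
        "http://localhost:10132/tf_transformer_predict",
        json={"cap": 153.0, "time_series": 1, "farm_id": "J00001"},  # illustrative payload only
        timeout=600,
    )
    print(resp.json())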

+ 6 - 6
models_processing/model_tf/tf_transformer_train.py

@@ -21,7 +21,7 @@ np.random.seed(42)  # NumPy random seed
 app = Flask('tf_lstm_train——service')
 
 current_dir = os.path.dirname(os.path.abspath(__file__))
-with open(os.path.join(current_dir, 'lstm.yaml'), 'r', encoding='utf-8') as f:
+with open(os.path.join(current_dir, 'transformer.yaml'), 'r', encoding='utf-8') as f:
     global_config = yaml.safe_load(f)  # read-only global config
 
 @app.before_request
@@ -41,7 +41,7 @@ def update_config():
     g.trans = TransformerHandler(logger, current_config)
 
 
-@app.route('/tf_lstm_training', methods=['POST'])
+@app.route('/tf_transformer_training', methods=['POST'])
 def model_training_lstm():
     # record the program start time
     start_time = time.time()
@@ -60,7 +60,7 @@ def model_training_lstm():
         # ------------ Train the model and save it ------------
         # 1. In add_train (incremental training) mode, first load the pre-trained model's feature parameters, then preprocess the training data
         # 2. In normal mode, first preprocess the training data, then build the model from the training data's features
-        model = trans.train_init() if trans.opt.Model['add_train'] else trans.get_transformer_model(trans.opt, time_series=args['time_series'], lstm_type=1)
+        model = trans.train_init() if trans.opt.Model['add_train'] else trans.get_transformer_model(trans.opt, time_series=args['time_series'])
         if trans.opt.Model['add_train']:
             if model:
                 feas = json.loads(trans.model_params)['features']
@@ -68,10 +68,10 @@ def model_training_lstm():
                     dh.opt.features = list(feas)
                     train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap = dh.train_data_handler(train_data, time_series=args['time_series'])
                 else:
-                    model = trans.get_transformer_model(trans.opt, time_series=args['time_series'], lstm_type=1)
+                    model = trans.get_transformer_model(trans.opt, time_series=args['time_series'])
                     logger.info("Training data features do not satisfy the pre-trained (add_train) model's features")
             else:
-                model = trans.get_transformer_model(trans.opt, time_series=args['time_series'], lstm_type=1)
+                model = trans.get_transformer_model(trans.opt, time_series=args['time_series'])
         ts_model = trans.training(model, [train_x, train_y, valid_x, valid_y])
         args['Model']['features'] = ','.join(dh.opt.features)
         args['params'] = json.dumps(args)
@@ -97,7 +97,7 @@ def model_training_lstm():
 if __name__ == "__main__":
     print("Program starts execution!")
     from waitress import serve
-    serve(app, host="0.0.0.0", port=10115,
+    serve(app, host="0.0.0.0", port=10131,
           threads=8,  # number of worker threads (default 4; adjust to the hardware)
           channel_timeout=600  # connection timeout (seconds)
           )
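
The training service gets the matching rename: /tf_lstm_training on 10115 becomes /tf_transformer_training on 10131, and get_transformer_model is now called without the lstm_type argument. A hypothetical request against the renamed endpoint (payload fields inferred from the handler, illustrative only):

    # hypothetical training request; the real payload also carries the training data and full Model settings
    import requests

    resp = requests.post(
        "http://localhost:10131/tf_transformer_training",
        json={"time_series": 1, "Model": {"add_train": False}},  # illustrative payload only
        timeout=600,
    )
    print(resp.json())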

+ 93 - 0
models_processing/model_tf/transformer.yaml

@@ -0,0 +1,93 @@
+Model:
+  add_train: false
+  batch_size: 64
+  dropout_rate: 0.2
+  epoch: 500
+  fusion: true
+  hidden_size: 64
+  his_points: 16
+  how_long_fill: 10
+  input_size: 24
+  lambda_value_1: 0.02
+  lambda_value_2: 0.01
+  learning_rate: 0.001
+  lstm_layers: 1
+  output_size: 16
+  patience: 10
+  predict_data_fill: true
+  shuffle_train_data: false
+  test_data_fill: false
+  time_step: 16
+  train_data_fill: true
+  use_cuda: false
+  valid_data_rate: 0.15
+use_bidirectional: false
+region: south
+calculate: []
+cap: 153.0
+dataloc: ./data
+env_columns:
+- C_TIME
+- C_CELLT
+- C_DIFFUSER
+- C_GLOBALR
+- C_RH
+- C_REAL_VALUE
+full_field: true
+history_hours: 1
+new_field: true
+features:
+- time
+- temperature10
+- temperature190
+- direction160
+- direction40
+- temperature110
+- direction80
+- speed60
+- mcc
+- temperature150
+- speed20
+- speed110
+- direction120
+- speed190
+- solarZenith
+- temperature90
+- direction200
+- speed150
+- temperature50
+- direction30
+- temperature160
+- direction170
+- temperature20
+- direction70
+- direction130
+- temperature200
+- speed70
+- temperature120
+- speed30
+- speed100
+- speed80
+- speed180
+- dniCalcd
+- speed140
+- temperature60
+- dateTime
+- temperature30
+- temperature170
+- direction20
+- humidity2
+- direction180
+- realPowerAvg
+- direction60
+- direction140
+- speed40
+- hcc
+target: realPower
+repair_days: 81
+repair_model_cycle: 5
+spot_trading: []
+update_add_train_days: 60
+update_coe_days: 3
+version: solar-3.1.0.south
+
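Both services now load this new transformer.yaml instead of lstm.yaml, and TransformerHandler reads its hyper-parameters from the Model section with defaults (e.g. opt.Model.get('hidden_size', 64)). A minimal sketch of how the file is consumed, using the plain dict rather than the opt wrapper and assuming the path shown in the diff:

    # minimal sketch of loading the new config; path assumed from the diff layout
    import os
    import yaml

    path = os.path.join('models_processing', 'model_tf', 'transformer.yaml')
    with open(path, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)                           # read-only global config

    hidden_size = cfg['Model'].get('hidden_size', 64)     # 64 in this file
    output_size = cfg['Model'].get('output_size', 16)     # 16 predicted points
    print(hidden_size, output_size, cfg['version'])       # -> 64 16 solar-3.1.0.south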