刘大为 3 hete
szülő
commit
f58ba241d5
3 módosított fájl, 209 hozzáadás és 74 törlés
  1. 39 50
      app/model/main.py
  2. 124 0
      app/model/main410.py
  3. 46 24
      app/predict/main.py

+ 39 - 50
app/model/main.py

@@ -1,14 +1,7 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# time: 2023/3/2 10:28
-# file: config.py
-# author: David
-# company: shenyang JY
-
-"""
-模型调参及系统功能配置
-"""
 import argparse
+import os
+import uuid
+from celery import Celery
 import pandas as pd
 from pathlib import Path
 from app.common.logs import params, logger
@@ -22,15 +15,39 @@ from app.common.logs import params, logger
         5. 执行预测,输出结果,输出状态
 """
 
+# ------------------- Celery task configuration -------------------
+# Celery application named 'power_tasks'. Broker and result backend both point
+# at host "redis", port 6379, using Redis databases 0 and 1 respectively.
+celery_app = Celery(
+    'power_tasks',
+    broker='redis://redis:6379/0',  # Redis is used as the message queue
+    backend='redis://redis:6379/1',  # task result storage
+    # NOTE(review): pickle can execute arbitrary code when it deserializes
+    # untrusted data. The task arguments visible in this diff are plain
+    # strings, so 'json' would presumably suffice — confirm before changing.
+    task_serializer='pickle',
+    result_serializer='pickle',
+    accept_content=['pickle']
+)
+
+# Worker concurrency is hard-coded to 4. The original note says it should not
+# exceed the number of CPU cores; nothing in this block enforces that.
+celery_app.conf.worker_concurrency = 4
+celery_app.conf.worker_prefetch_multiplier = 1  # intended to keep tasks from piling up
+
+""""
+调用思路
+   xxxx 1. 从入口参数中获取IN OUT文件位置 xxxx
+        2. 按照训练和预测加载和解析数据
+        3. 对数据进行预处理
+        4. 执行训练,保存模型,输出状态
+        5. 执行预测,输出结果,输出状态
+"""
+
+
 def material(input_file, isDq=True):
     basi, station_info_w, station_info_d_w, station_info_s, station_info_d_s, nwp_w, nwp_s, nwp_w_h, nwp_s_h, power = (
         'DQYC_IN_BASIC.txt', 'DQYC_IN_PLANT_WIND.txt', 'DQYC_IN_PLANT_DETAIL_WIND.txt', 'DQYC_IN_PLANT_SOLAR.txt',
         'DQYC_IN_PLANT_DETAIL_SOLAR.txt', 'DQYC_IN_FORECAST_WEATHER_WIND.txt', 'DQYC_IN_FORECAST_WEATHER_SOLAR.txt',
         'DQYC_IN_FORECAST_WEATHER_WIND_H.txt', 'DQYC_IN_FORECAST_WEATHER_SOLAR_H.txt', 'DQYC_IN_HISTORY_POWER_LONG.txt')
     basi_area = 'DQYC_AREA_IN_BASIC'
-    nwp_v, nwp_v_h = 'DQYC_IN_FORECAST_WEATHER.txt', 'DQYC_IN_FORECAST_WEATHER_H.txt' # 多版本气象
-    nwp_own, nwp_own_h = 'DQYC_IN_FORECAST_WEATHER_OWNER.txt', 'DQYC_IN_FORECAST_WEATHER_OWNER_H.txt', # 自有气象
-    env_wf, env_sf = 'DQYC_IN_ACTUAL_WEATHER_WIND', 'DQYC_IN_ACTUAL_WEATHER_SOLAR' # 实测气象
+    nwp_v, nwp_v_h = 'DQYC_IN_FORECAST_WEATHER.txt', 'DQYC_IN_FORECAST_WEATHER_H.txt'  # 多版本气象
+    nwp_own, nwp_own_h = 'DQYC_IN_FORECAST_WEATHER_OWNER.txt', 'DQYC_IN_FORECAST_WEATHER_OWNER_H.txt',  # 自有气象
+    env_wf, env_sf = 'DQYC_IN_ACTUAL_WEATHER_WIND', 'DQYC_IN_ACTUAL_WEATHER_SOLAR'  # 实测气象
     input_file = Path(input_file)
     env_w, env_s = None, None
     basic = pd.read_csv(input_file / basi, sep=r'\s+', header=0)
@@ -66,6 +83,7 @@ def material(input_file, isDq=True):
         basic_area = pd.read_csv(input_file / basi_area, sep=r'\s+', header=0)
         return basic_area
 
+
 def clean_power(power, env, plant_id):
     env_power = pd.merge(env, power, on=params['col_time'])
     if 'HubSpeed' in env.columns.tolist():
@@ -79,46 +97,17 @@ def clean_power(power, env, plant_id):
     return power
 
 
-def input_file_handler(input_file: str, model_name: str):
-    # DQYC:短期预测,qy:区域级
-    if 'dqyc' in input_file.lower():
-        station_info, station_info_d, nwp, nwp_h, power, nwp_v, nwp_v_h, env = material(input_file, True)
-        cap = round(float(station_info['PlantCap'][0]), 2)
-        plant_id = int(station_info['PlantID'][0])
-        # 含有model,训练
-        if 'model' in input_file.lower():
-            if env is not None and params['clean_power']:     # 进行限电清洗
-                power = clean_power(power, env, plant_id)
-            train_data = pd.merge(nwp_v_h, power, on=params['col_time'])
-            if model_name == 'fmi':
-                from app.model.tf_fmi_train import model_training
-            elif model_name == 'cnn':
-                from app.model.tf_cnn_train import model_training
-            else:
-                from app.model.tf_lstm_train import model_training
-            model_training(train_data, input_file, cap)
-        # 含有predict,预测
-        else:
-            logger.info("训练路径错误!")
-    else:
-        # 区域级预测:未完
-        basic_area = material(input_file, False)
-
 def main():
-    # 创建解析器对象
-    parser = argparse.ArgumentParser(description="程序描述")
-    # 创建
-    # 添加参数
-    parser.add_argument("input_file", help="输入文件路径")    # 第一个位置参数
-
-    parser.add_argument("--model_name", default="lstm", help='选择短期模型')    # 第二个位置参数
-    # 解析参数
+    """Command-line entry point: parse arguments and enqueue one training task."""
+    parser = argparse.ArgumentParser(description="功率预测程序")
+    parser.add_argument("input_file", help="输入文件路径")
+    parser.add_argument("--model_name", default="lstm", help="选择短期模型")
     args = parser.parse_args()
 
-    # 使用参数
-    print(f"文件: {args.input_file}")
-    input_file_handler(args.input_file, args.model_name)
+    # A fresh UUID4 string identifies this submission.
+    task_id = str(uuid.uuid4())
+    # NOTE(review): async_input_handler is neither defined nor imported in any
+    # hunk of this file shown in the diff, and input_file_handler was removed.
+    # Unless it exists in an unshown part of the file, this line raises
+    # NameError — confirm.
+    async_input_handler.delay(task_id, args.input_file, args.model_name)
+    # NOTE(review): task_id is only passed as a task argument; nothing visible
+    # here makes it the Celery task id, so the /api/task/<id> hint may not
+    # resolve to this task — confirm.
+    print(f"训练任务已提交 | ID: {task_id} | 查看状态: /api/task/{task_id}")
 
 
 if __name__ == "__main__":
-    main()
+    main()

+ 124 - 0
app/model/main410.py

@@ -0,0 +1,124 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# time: 2023/3/2 10:28
+# file: config.py
+# author: David
+# company: shenyang JY
+
+"""
+模型调参及系统功能配置
+"""
+import argparse
+import pandas as pd
+from pathlib import Path
+from app.common.logs import params, logger
+
+""""
+调用思路
+   xxxx 1. 从入口参数中获取IN OUT文件位置 xxxx
+        2. 按照训练和预测加载和解析数据
+        3. 对数据进行预处理
+        4. 执行训练,保存模型,输出状态
+        5. 执行预测,输出结果,输出状态
+"""
+
+def material(input_file, isDq=True):
+    """Load the input tables for one run from a directory of DQYC_* files.
+
+    Args:
+        input_file: Directory (str or Path) that holds the input files.
+        isDq: True loads the per-plant inputs; False loads only the
+            area-level basic table.
+
+    Returns:
+        When isDq is True, the 8-tuple
+        (station_info, station_info_d, nwp, nwp_h, power, nwp_v, nwp_v_h, env),
+        where env is None if the measured-weather file does not exist.
+        When isDq is False, the area basic DataFrame.
+
+    Raises:
+        ValueError: isDq is True and PlantType is neither 0 (wind) nor
+            1 (solar). This case used to fall through and return None, which
+            only surfaced later as an unpacking error in the caller.
+    """
+    def _read(path):
+        # Every input file is a whitespace-separated table with a header row.
+        return pd.read_csv(path, sep=r'\s+', header=0)
+
+    # Per plant type: station info, station detail, forecast weather,
+    # forecast weather "_H" variant, measured weather.
+    # NOTE(review): the measured-weather names carry no ".txt" suffix, unlike
+    # the other inputs — confirm against the real file layout.
+    plant_files = {
+        0: ('DQYC_IN_PLANT_WIND.txt', 'DQYC_IN_PLANT_DETAIL_WIND.txt',
+            'DQYC_IN_FORECAST_WEATHER_WIND.txt', 'DQYC_IN_FORECAST_WEATHER_WIND_H.txt',
+            'DQYC_IN_ACTUAL_WEATHER_WIND'),
+        1: ('DQYC_IN_PLANT_SOLAR.txt', 'DQYC_IN_PLANT_DETAIL_SOLAR.txt',
+            'DQYC_IN_FORECAST_WEATHER_SOLAR.txt', 'DQYC_IN_FORECAST_WEATHER_SOLAR_H.txt',
+            'DQYC_IN_ACTUAL_WEATHER_SOLAR'),
+    }
+
+    input_file = Path(input_file)
+    basic = _read(input_file / 'DQYC_IN_BASIC.txt')
+    power = _read(input_file / 'DQYC_IN_HISTORY_POWER_LONG.txt')
+    # The basic table is a PropertyID/Value listing; the position of the
+    # 'PlantType' row is used as the .loc label.
+    plant_type = int(basic.loc[basic['PropertyID'].to_list().index('PlantType'), 'Value'])
+
+    if not isDq:
+        # Area-level prediction is still to be decided; it may need to iterate
+        # over the station data.
+        # NOTE(review): this file name has no ".txt" suffix either — confirm.
+        return _read(input_file / 'DQYC_AREA_IN_BASIC')
+
+    if plant_type not in plant_files:
+        raise ValueError(
+            f"Unsupported PlantType {plant_type!r}: expected 0 (wind) or 1 (solar)")
+    info_name, detail_name, nwp_name, nwp_h_name, env_name = plant_files[plant_type]
+
+    # Multi-version forecast weather lives under '0', owner forecast weather under '1'.
+    nwp_v = _read(input_file / '0' / 'DQYC_IN_FORECAST_WEATHER.txt')
+    nwp_v_h = _read(input_file / '0' / 'DQYC_IN_FORECAST_WEATHER_H.txt')
+    nwp_own = _read(input_file / '1' / 'DQYC_IN_FORECAST_WEATHER_OWNER.txt')
+    nwp_own_h = _read(input_file / '1' / 'DQYC_IN_FORECAST_WEATHER_OWNER_H.txt')
+    if params['switch_nwp_owner']:
+        nwp_v, nwp_v_h = nwp_own, nwp_own_h
+
+    station_info = _read(input_file / info_name)
+    station_info_d = _read(input_file / detail_name)
+    nwp = _read(input_file / nwp_name)
+    nwp_h = _read(input_file / nwp_h_name)
+    # Measured weather is optional: absent file -> None.
+    env_path = input_file / env_name
+    env = _read(env_path) if env_path.exists() else None
+    return station_info, station_info_d, nwp, nwp_h, power, nwp_v, nwp_v_h, env
+
+def clean_power(power, env, plant_id):
+    """Run the limited-power cleaning that matches the columns present in env.
+
+    env is merged with power on params['col_time']. A 'HubSpeed' column
+    selects the wind LimitPower implementation; otherwise an 'Irradiance'
+    column selects the solar one. With neither column, power is returned
+    unchanged.
+    """
+    merged = pd.merge(env, power, on=params['col_time'])
+    env_columns = env.columns.tolist()
+    if 'HubSpeed' in env_columns:
+        from app.common.limited_power_wind import LimitPower
+    elif 'Irradiance' in env_columns:
+        from app.common.limited_power_solar import LimitPower
+    else:
+        return power
+    limiter = LimitPower(logger, params, merged)
+    return limiter.clean_limited_power(plant_id, True)
+
+
+def input_file_handler(input_file: str, model_name: str):
+    """Dispatch a training run based on keywords found in the input path.
+
+    A path without 'dqyc' only loads the area-level table (region-level
+    prediction is unfinished). A 'dqyc' path loads the plant inputs; training
+    runs only if the path also contains 'model', otherwise a message is logged.
+    model_name 'fmi' or 'cnn' selects that trainer; anything else uses LSTM.
+    """
+    path_lower = input_file.lower()
+
+    # Region-level prediction: not finished, the table is loaded and dropped.
+    if 'dqyc' not in path_lower:
+        material(input_file, False)
+        return
+
+    # DQYC: short-term prediction at plant level.
+    (station_info, _station_detail, _nwp, _nwp_h,
+     power, _nwp_v, nwp_v_h, env) = material(input_file, True)
+    cap = round(float(station_info['PlantCap'][0]), 2)
+    plant_id = int(station_info['PlantID'][0])
+
+    # Only paths containing 'model' are training paths.
+    if 'model' not in path_lower:
+        logger.info("训练路径错误!")
+        return
+
+    # Optional limited-power cleaning when measured weather is available.
+    if env is not None and params['clean_power']:
+        power = clean_power(power, env, plant_id)
+    train_data = pd.merge(nwp_v_h, power, on=params['col_time'])
+
+    if model_name == 'fmi':
+        from app.model.tf_fmi_train import model_training
+    elif model_name == 'cnn':
+        from app.model.tf_cnn_train import model_training
+    else:
+        from app.model.tf_lstm_train import model_training
+    model_training(train_data, input_file, cap)
+
+def main():
+    """Parse the command line and run input_file_handler synchronously."""
+    cli = argparse.ArgumentParser(description="程序描述")
+    # Required positional argument: the input path.
+    cli.add_argument("input_file", help="输入文件路径")
+    # Optional flag choosing the short-term model; defaults to "lstm".
+    cli.add_argument("--model_name", default="lstm", help='选择短期模型')
+    opts = cli.parse_args()
+
+    print(f"文件: {opts.input_file}")
+    input_file_handler(opts.input_file, opts.model_name)
+
+
+if __name__ == "__main__":
+    main()

+ 46 - 24
app/predict/main.py

@@ -8,8 +8,9 @@
 """
 模型调参及系统功能配置
 """
-import argparse
+import argparse, uuid
 import pandas as pd
+from celery import Celery
 from pathlib import Path
 from app.common.logs import logger, params
 
@@ -21,6 +22,20 @@ from app.common.logs import logger, params
         4. 执行训练,保存模型
         5. 执行预测,输出结果
 """
+# ------------------- Celery task configuration -------------------
+# Celery application named 'power_tasks'. Broker and result backend both point
+# at host "redis", port 6379, using Redis databases 0 and 1 respectively.
+celery_app = Celery(
+    'power_tasks',
+    broker='redis://redis:6379/0',  # Redis is used as the message queue
+    backend='redis://redis:6379/1',  # task result storage
+    # NOTE(review): pickle can execute arbitrary code when it deserializes
+    # untrusted data. The task in this file takes strings and returns a dict
+    # of strings, so 'json' would presumably suffice — confirm before changing.
+    task_serializer='pickle',
+    result_serializer='pickle',
+    accept_content=['pickle']
+)
+
+# Worker concurrency is hard-coded to 4. The original note says it should not
+# exceed the number of CPU cores; nothing in this block enforces that.
+celery_app.conf.worker_concurrency = 4
+celery_app.conf.worker_prefetch_multiplier = 1  # intended to keep tasks from piling up
+
 
 def material(input_file, isDq=True):
     basi, station_info_w, station_info_d_w, station_info_s, station_info_d_s, nwp_w, nwp_s, nwp_w_h, nwp_s_h, power = (
@@ -60,28 +75,33 @@ def material(input_file, isDq=True):
         basic_area = pd.read_csv(input_file / basi_area, sep=r'\s+', header=0)
         return basic_area
 
-def input_file_handler(input_file: str, model_name: str):
-    # DQYC:短期预测,qy:区域级
-    if 'dqyc' in input_file.lower():
-        station_info, station_info_d, nwp, nwp_h, power, nwp_v, nwp_v_h = material(input_file, True)
-        cap = round(float(station_info['PlantCap'][0]), 2)
-        # 含有predict,预测
-        if 'predict' in input_file.lower():
-            pre_data = nwp_v
-            if model_name == 'fmi':
-                from app.predict.tf_fmi_pre import model_prediction
-            elif model_name == 'cnn':
-                from app.predict.tf_cnn_pre import model_prediction
+
+@celery_app.task(bind=True, max_retries=3, time_limit=600)
+def input_file_handler(self, task_id: str, input_file: str, model_name: str):
+    """Celery task: run a prediction for the inputs found under input_file.
+
+    Args:
+        task_id: Caller-supplied identifier, echoed back in the result.
+        input_file: Input directory path. Its lower-cased text must contain
+            both 'dqyc' and 'predict' for a prediction to run.
+        model_name: 'fmi' or 'cnn' selects that predictor; any other value
+            selects the LSTM predictor.
+
+    Returns:
+        {"status": "success", "task_id": task_id} when no exception is raised.
+
+    Raises:
+        Whatever self.retry raises. Any exception schedules a retry after
+        2 ** retries seconds (max_retries=3 on the decorator).
+    """
+    try:
+        # DQYC: short-term prediction; qy: region level
+        if 'dqyc' in input_file.lower():
+            station_info, station_info_d, nwp, nwp_h, power, nwp_v, nwp_v_h = material(input_file, True)
+            cap = round(float(station_info['PlantCap'][0]), 2)
+            # Path contains 'predict': run the prediction
+            if 'predict' in input_file.lower():
+                pre_data = nwp_v
+                if model_name == 'fmi':
+                    from app.predict.tf_fmi_pre import model_prediction
+                elif model_name == 'cnn':
+                    from app.predict.tf_cnn_pre import model_prediction
+                else:
+                    from app.predict.tf_lstm_pre import model_prediction
+                model_prediction(pre_data, input_file, cap)
             else:
-                from app.predict.tf_lstm_pre import model_prediction
-            model_prediction(pre_data, input_file, cap)
+                logger.info("预测路径错误!")
         else:
-            logger.info("预测路径错误!")
-    else:
-        # 区域级预测:未完
-        # basic_area = material(input_file, False)
-        logger.info("区域级预测待开放。")
-
+            # Region-level prediction: not finished
+            # basic_area = material(input_file, False)
+            logger.info("区域级预测待开放。")
+        # NOTE(review): the branches above that only log (wrong path, region
+        # level) also reach this "success" result — confirm this is intended.
+        return {"status": "success", "task_id": task_id}
+    except Exception as e:
+        # retry() raises by default; the explicit `raise` is the documented
+        # form and makes the exit from this handler visible.
+        raise self.retry(exc=e, countdown=2 ** self.request.retries)
 
 
 def main():
@@ -89,15 +109,17 @@ def main():
     parser = argparse.ArgumentParser(description="程序描述")
     # 创建
     # 添加参数
-    parser.add_argument("input_file", help="输入文件路径")    # 第一个位置参数
+    parser.add_argument("input_file", help="输入文件路径")  # 第一个位置参数
     parser.add_argument("--model_name", default="cnn", help='选择短期模型')  # 第二个位置参数(可选)
     # 解析参数
     args = parser.parse_args()
 
     # 使用参数
     print(f"文件: {args.input_file}")
-    input_file_handler(args.input_file, args.model_name)
+    task_id = str(uuid.uuid4())
+    input_file_handler.delay(task_id, args.input_file, args.model_name)
+    print(f"任务已提交 | ID: {task_id} | 查看状态: /api/task/{task_id}")
 
 
 if __name__ == "__main__":
-    main()
+    main()