main.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # time: 2023/3/2 10:28
  4. # file: main.py
  5. # author: David
  6. # company: shenyang JY
  7. """
  8. 模型调参及系统功能配置
  9. """
  10. from tqdm import tqdm
  11. import argparse, yaml
  12. import pandas as pd
  13. from pathlib import Path
  14. from functools import partial
  15. from concurrent.futures import ProcessPoolExecutor
  16. from app.common.logs import params, logger
  17. from app.model.config import myargparse
  18. from app.model.resource_manager import ResourceController
  19. from app.model.task_worker import station_task
  20. """
  21. 调用思路
  22. xxxx 1. 从入口参数中获取IN OUT文件位置 xxxx
  23. 2. 按照训练和预测加载和解析数据
  24. 3. 对数据进行预处理
  25. 4. 执行训练,保存模型,输出状态
  26. 5. 执行预测,输出结果,输出状态
  27. """
  28. """
  29. 训练任务
  30. 1.将一个省份下的所有场站加入队列
  31. 2.队列中的每个场站是一个子任务,还有最终的区域级子任务
  32. """
  def clean_power(power, env, plant_id):
      """Clean curtailed (limited-power) records out of a plant's power data.

      The environment frame is merged with the power frame on the configured
      time column, then the cleaner matching the plant type is applied: the
      wind implementation when ``env`` has a ``HubSpeed`` column, otherwise
      the solar implementation when it has an ``Irradiance`` column.

      Args:
          power: Power DataFrame containing the ``params['col_time']`` column.
          env: Environment DataFrame sharing that time column.
          plant_id: Plant identifier forwarded to the cleaner.

      Returns:
          The cleaner's result, or ``power`` unchanged when ``env`` carries
          neither marker column.
      """
      merged = pd.merge(env, power, on=params['col_time'])
      column_names = env.columns.tolist()

      # The branch only decides which LimitPower implementation to load;
      # both implementations are driven the same way afterwards.
      if 'HubSpeed' in column_names:
          from app.common.limited_power_wind import LimitPower
      elif 'Irradiance' in column_names:
          from app.common.limited_power_solar import LimitPower
      else:
          return power

      cleaner = LimitPower(logger, params, merged)
      return cleaner.clean_limited_power(plant_id, True)
  def input_file_handler(input_file: str, model_name: str):
      """Dispatch one input path to station-level training or the regional branch.

      Paths containing ``dqyc`` (case-insensitive) are handled as short-term
      station forecasts; every other path goes to the regional branch, which
      is unfinished and only loads its material.

      Args:
          input_file: Input path; its text selects the branch that runs.
          model_name: ``'fmi'`` or ``'cnn'`` pick those trainers; any other
              value falls back to the LSTM trainer.
      """
      # NOTE(review): `material` is not imported in the visible part of this
      # file -- confirm where it is defined before relying on this function.
      path_lower = input_file.lower()

      if 'dqyc' not in path_lower:
          # Regional-level forecast: not finished, only the loading step exists.
          basic_area = material(input_file, False)
          return

      (station_info, station_info_d, nwp, nwp_h,
       power, nwp_v, nwp_v_h, env) = material(input_file, True)
      capacity = round(float(station_info['PlantCap'][0]), 2)
      station_id = int(station_info['PlantID'][0])

      if 'model' not in path_lower:
          # Only paths containing 'model' are training inputs; anything else
          # is reported and skipped.
          logger.info("训练路径错误!")
          return

      # Optional curtailment cleaning, enabled through params['clean_power'].
      if env is not None and params['clean_power']:
          power = clean_power(power, env, station_id)
      training_frame = pd.merge(nwp_v_h, power, on=params['col_time'])

      # Trainers are imported lazily so only the selected backend is loaded.
      if model_name == 'fmi':
          from app.model.tf_fmi_train import model_training
      elif model_name == 'cnn':
          from app.model.tf_cnn_train import model_training
      else:
          from app.model.tf_lstm_train import model_training
      model_training(training_frame, input_file, capacity)
  def get_station_list(input_file):
      """Return the immediate sub-directories of ``input_file`` as path strings.

      Each sub-directory is treated by the caller as one station task. Plain
      files are ignored. ``Path.iterdir`` yields entries in arbitrary order,
      so the result is sorted to keep task submission order reproducible
      between runs and machines.

      Args:
          input_file: Path of the directory to scan.

      Returns:
          Sorted list of sub-directory paths as strings.
      """
      return sorted(
          str(child) for child in Path(input_file).iterdir() if child.is_dir()
      )
  def main():
      """Parse the command line, run the single-path handler, then train all stations.

      Steps:
        1. Parse ``input_file`` and ``--model_name`` with the project parser.
        2. Run ``input_file_handler`` on the input path.
        3. Build a ``ResourceController`` from ``opt.system`` and submit one
           ``station_task`` per sub-directory of ``input_file`` to a process
           pool, giving each task a GPU id from the controller.
        4. Collect the results and report how many stations returned
           ``status == 'success'``.
      """
      # The project parser is used instead of a plain argparse.ArgumentParser.
      parser = myargparse(description='算法', add_help=False)
      parser.add_argument("input_file", help="输入文件路径")  # positional argument
      parser.add_argument("--model_name", default="lstm", help='选择短期模型')  # optional flag
      opt = parser.parse_args_and_yaml()
      config = opt.__dict__

      print(f"文件: {opt.input_file}")
      input_file_handler(opt.input_file, opt.model_name)

      # ----------------------------
      # Resource manager that hands out and takes back GPU ids.
      rc = ResourceController(
          max_workers=opt.system['max_workers'],
          gpu_list=opt.system['gpu_devices'],
      )

      all_stations = get_station_list(opt.input_file)
      # NOTE(review): task_func already binds `config` by keyword, and submit()
      # below also passes task_config positionally. If station_task's second
      # positional parameter is named `config`, the worker call fails with
      # "multiple values for argument" -- verify the signature.
      task_func = partial(station_task, config=config)

      completed = 0
      with tqdm(total=len(all_stations)) as pbar:
          with ProcessPoolExecutor(max_workers=rc.cpu_cores) as executor:
              futures = []
              for sid in all_stations:
                  # Assign a GPU to this task.
                  gpu_id = rc.get_gpu()
                  task_config = config.copy()
                  task_config['gpu_assignment'] = gpu_id

                  future = executor.submit(task_func, sid, task_config)
                  # Bind gpu_id as a default argument: a plain closure would
                  # look up the loop variable when the callback fires and
                  # release the GPU of a later-submitted task instead.
                  future.add_done_callback(
                      lambda _done, gpu=gpu_id: rc.release_gpu(gpu))
                  futures.append(future)

              # Collect results in submission order.
              for future in futures:
                  result = future.result()
                  if result['status'] == 'success':
                      completed += 1
                  pbar.update(1)
                  pbar.set_postfix_str(f"Completed: {completed}/{len(all_stations)}")

      print(f"Final result: {completed} stations trained successfully")
  if __name__ == "__main__":
      # Run the pipeline only when executed as a script, not on import
      # (worker processes may re-import this module).
      main()