123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- # time: 2023/3/2 10:28
- # file: config.py
- # author: David
- # company: shenyang JY
- """
- 模型调参及系统功能配置
- """
- from tqdm import tqdm
- import argparse, yaml
- import pandas as pd
- from pathlib import Path
- from functools import partial
- from concurrent.futures import ProcessPoolExecutor
- from app.common.logs import params, logger
- from app.model.config import myargparse
- from app.model.resource_manager import ResourceController
- from app.model.task_worker import station_task
"""
调用思路
xxxx 1. 从入口参数中获取IN OUT文件位置 xxxx
2. 按照训练和预测加载和解析数据
3. 对数据进行预处理
4. 执行训练,保存模型,输出状态
5. 执行预测,输出结果,输出状态
"""
- """
- 训练任务
- 1.将一个省份下的所有场站加入队列
- 2.队列中的每个场站是一个子任务,还有最终的区域级子任务
- """
def clean_power(power, env, plant_id):
    """Run power-limit (curtailment) cleaning on ``power`` using the plant's env data.

    ``env`` and ``power`` are inner-joined on ``params['col_time']`` and handed to a
    ``LimitPower`` cleaner. The wind implementation is used when ``env`` has a
    ``HubSpeed`` column; otherwise the solar implementation is used when it has an
    ``Irradiance`` column.

    :param power: power data, joinable with ``env`` on ``params['col_time']``
    :param env: environment data, joinable with ``power`` on ``params['col_time']``
    :param plant_id: plant identifier forwarded to ``clean_limited_power``
    :return: whatever ``LimitPower.clean_limited_power`` returns, or ``power``
        unchanged when ``env`` has neither a ``HubSpeed`` nor an ``Irradiance`` column
    """
    merged = pd.merge(env, power, on=params['col_time'])
    env_columns = env.columns.tolist()
    if 'HubSpeed' in env_columns:
        # HubSpeed present -> wind-plant cleaner.
        from app.common.limited_power_wind import LimitPower
    elif 'Irradiance' in env_columns:
        # Irradiance present -> solar-plant cleaner.
        from app.common.limited_power_solar import LimitPower
    else:
        # Plant type cannot be told from env; leave power as it is.
        return power
    cleaner = LimitPower(logger, params, merged)
    # NOTE(review): the meaning of the positional True flag is defined by LimitPower - confirm.
    return cleaner.clean_limited_power(plant_id, True)
def input_file_handler(input_file: str, model_name: str):
    """Dispatch one input path to station-level training or region-level handling.

    The lower-cased path selects the action:

    * contains ``dqyc`` (short-term forecast): load station data with
      ``material(input_file, True)``. If the path also contains ``model``, optionally
      clean curtailed power, inner-join ``nwp_v_h`` with power on ``params['col_time']``
      and train the model selected by ``model_name``. Otherwise log a path error.
    * anything else (region level, ``qy``): load data with ``material(input_file, False)``.
      This branch is unfinished and does nothing with the result.

    :param input_file: path of the input data
    :param model_name: ``'fmi'`` or ``'cnn'``; any other value selects the LSTM trainer
    :return: None
    """
    # NOTE(review): `material` is not imported in this module - confirm where it is provided.
    lowered = input_file.lower()
    if 'dqyc' not in lowered:
        # Region-level forecasting is not implemented yet; the loaded data is unused.
        basic_area = material(input_file, False)
        return

    (station_info, station_info_d, nwp, nwp_h,
     power, nwp_v, nwp_v_h, env) = material(input_file, True)
    cap = round(float(station_info['PlantCap'][0]), 2)
    plant_id = int(station_info['PlantID'][0])

    if 'model' not in lowered:
        # Only training paths (containing "model") are handled; prediction is not wired up here.
        logger.info("训练路径错误!")
        return

    if env is not None and params['clean_power']:
        # Remove power-limited (curtailed) periods before training.
        power = clean_power(power, env, plant_id)
    train_data = pd.merge(nwp_v_h, power, on=params['col_time'])

    if model_name == 'fmi':
        from app.model.tf_fmi_train import model_training
    elif model_name == 'cnn':
        from app.model.tf_cnn_train import model_training
    else:
        from app.model.tf_lstm_train import model_training
    model_training(train_data, input_file, cap)
def get_station_list(input_file):
    """Return the immediate sub-directories of ``input_file`` as path strings.

    Each sub-directory is treated as one station's data folder; plain files are
    ignored. The result is sorted so that task submission order (and therefore
    GPU assignment and log output) is reproducible across runs and filesystems.

    :param input_file: directory containing one sub-directory per station
    :return: sorted list of sub-directory paths as ``str`` (empty if there are none)
    :raises FileNotFoundError: if ``input_file`` does not exist
    :raises NotADirectoryError: if ``input_file`` is not a directory
    """
    # Path.iterdir() yields entries in arbitrary order; sort for determinism.
    return sorted(str(child) for child in Path(input_file).iterdir() if child.is_dir())
def main():
    """Command-line entry point: parse options, then train every station in parallel.

    Steps:
      1. Parse the positional ``input_file`` and the ``--model_name`` option with
         ``myargparse.parse_args_and_yaml`` (presumably also merging a YAML config,
         which is where ``opt.system`` comes from - confirm).
      2. Run ``input_file_handler`` on the input path.
      3. Treat every sub-directory of ``input_file`` as one station. Submit one
         ``station_task`` per station to a process pool, giving each task a copy of
         the config with its own ``gpu_assignment`` from ``ResourceController.get_gpu``.
         The GPU is released through ``release_gpu`` when that task's future finishes.
      4. Show progress with ``tqdm`` as tasks complete and print how many succeeded.
    """
    from concurrent.futures import as_completed

    # Build the parser and declare the CLI arguments.
    parser = myargparse(description='算法', add_help=False)
    parser.add_argument("input_file", help="输入文件路径")  # positional: input path
    parser.add_argument("--model_name", default="lstm", help='选择短期模型')  # optional: short-term model
    opt = parser.parse_args_and_yaml()
    config = vars(opt)

    print(f"文件: {opt.input_file}")
    # NOTE(review): input_file_handler calls `material`, which is not imported in this
    # module; unless it is provided elsewhere this raises NameError before the station
    # pool below runs. TODO confirm whether this single-path call should remain here.
    input_file_handler(opt.input_file, opt.model_name)

    # Resource manager hands out GPUs to tasks and takes them back.
    rc = ResourceController(
        max_workers=opt.system['max_workers'],
        gpu_list=opt.system['gpu_devices'],
    )

    all_stations = get_station_list(opt.input_file)
    total = len(all_stations)
    completed = 0
    with tqdm(total=total) as pbar:
        with ProcessPoolExecutor(max_workers=rc.cpu_cores) as executor:
            future_to_station = {}
            for sid in all_stations:
                gpu_id = rc.get_gpu()
                task_config = config.copy()
                task_config['gpu_assignment'] = gpu_id
                # Pass the per-task config directly. The previous
                # partial(station_task, config=config)(sid, task_config) supplied `config`
                # twice when station_task's second parameter is `config` (TypeError), and
                # would have discarded the per-task GPU assignment in any case.
                future = executor.submit(station_task, sid, config=task_config)
                # Bind gpu_id now: a closure over task_config is evaluated only when the
                # task finishes, after the loop has rebound it, releasing the wrong GPU.
                future.add_done_callback(lambda _fut, gpu=gpu_id: rc.release_gpu(gpu))
                future_to_station[future] = sid

            # Consume results in completion order so the progress bar reflects reality;
            # one failed station is logged instead of aborting all reporting.
            for future in as_completed(future_to_station):
                try:
                    result = future.result()
                except Exception as exc:
                    logger.error(f"Station task failed: {future_to_station[future]}: {exc!r}")
                else:
                    if result['status'] == 'success':
                        completed += 1
                pbar.update(1)
                pbar.set_postfix_str(f"Completed: {completed}/{total}")
    print(f"Final result: {completed} stations trained successfully")


if __name__ == "__main__":
    main()
|