|
@@ -33,55 +33,12 @@ from app.model.task_worker import station_task
|
|
|
"""
|
|
|
|
|
|
|
|
|
-def clean_power(power, env, plant_id):
|
|
|
- env_power = pd.merge(env, power, on=params['col_time'])
|
|
|
- if 'HubSpeed' in env.columns.tolist():
|
|
|
- from app.common.limited_power_wind import LimitPower
|
|
|
- lp = LimitPower(logger, params, env_power)
|
|
|
- power = lp.clean_limited_power(plant_id, True)
|
|
|
- elif 'Irradiance' in env.columns.tolist():
|
|
|
- from app.common.limited_power_solar import LimitPower
|
|
|
- lp = LimitPower(logger, params, env_power)
|
|
|
- power = lp.clean_limited_power(plant_id, True)
|
|
|
- return power
|
|
|
-
|
|
|
-
|
|
|
-def input_file_handler(input_file: str, model_name: str):
|
|
|
- # DQYC:短期预测,qy:区域级
|
|
|
- if 'dqyc' in input_file.lower():
|
|
|
- station_info, station_info_d, nwp, nwp_h, power, nwp_v, nwp_v_h, env = material(input_file, True)
|
|
|
- cap = round(float(station_info['PlantCap'][0]), 2)
|
|
|
- plant_id = int(station_info['PlantID'][0])
|
|
|
- # 含有model,训练
|
|
|
- if 'model' in input_file.lower():
|
|
|
- if env is not None and params['clean_power']: # 进行限电清洗
|
|
|
- power = clean_power(power, env, plant_id)
|
|
|
- train_data = pd.merge(nwp_v_h, power, on=params['col_time'])
|
|
|
- if model_name == 'fmi':
|
|
|
- from app.model.tf_fmi_train import model_training
|
|
|
- elif model_name == 'cnn':
|
|
|
- from app.model.tf_cnn_train import model_training
|
|
|
- else:
|
|
|
- from app.model.tf_lstm_train import model_training
|
|
|
- model_training(train_data, input_file, cap)
|
|
|
- # 含有predict,预测
|
|
|
- else:
|
|
|
- logger.info("训练路径错误!")
|
|
|
- else:
|
|
|
- # 区域级预测:未完
|
|
|
- basic_area = material(input_file, False)
|
|
|
-
|
|
|
-def get_station_list(input_file):
|
|
|
- return [str(child) for child in Path(input_file).iterdir() if child.is_dir()]
|
|
|
-
|
|
|
def main():
|
|
|
+ # ---------------------------- 解析参数 ----------------------------
|
|
|
# 创建解析器对象
|
|
|
- # parser = argparse.ArgumentParser(description="程序描述")
|
|
|
parser = myargparse(description='算法', add_help=False)
|
|
|
- # 创建
|
|
|
# 添加参数
|
|
|
parser.add_argument("input_file", help="输入文件路径") # 第一个位置参数
|
|
|
-
|
|
|
parser.add_argument("--model_name", default="lstm", help='选择短期模型') # 第二个位置参数
|
|
|
# 解析参数
|
|
|
opt = parser.parse_args_and_yaml()
|
|
@@ -89,21 +46,18 @@ def main():
|
|
|
# 使用参数
|
|
|
print(f"文件: {opt.input_file}")
|
|
|
|
|
|
- input_file_handler(opt.input_file, opt.model_name)
|
|
|
-
|
|
|
-
|
|
|
- # ----------------------------
|
|
|
-
|
|
|
- # 初始化资源管理器
|
|
|
+ # ---------------------------- 配置计算资源和任务 ----------------------------
|
|
|
+ # 初始化资源管理器
|
|
|
rc = ResourceController(
|
|
|
max_workers=opt.system['max_workers'],
|
|
|
gpu_list=opt.system['gpu_devices']
|
|
|
)
|
|
|
|
|
|
# 生成任务列表
|
|
|
- all_stations = get_station_list(opt.input_file)
|
|
|
+ all_stations = [str(child) for child in Path(opt.input_file).iterdir() if child.is_dir()]
|
|
|
task_func = partial(station_task, config=config)
|
|
|
|
|
|
+ # ---------------------------- 监控任务 ----------------------------
|
|
|
# 进度跟踪
|
|
|
completed = 0
|
|
|
with tqdm(total=len(all_stations)) as pbar:
|