|
@@ -47,15 +47,9 @@ def add_nwp(df_obj, df):
|
|
|
df_obj[add_cols] = df_obj[add_cols].add(df, fill_value=0)
|
|
|
return df_obj
|
|
|
|
|
|
-def main():
|
|
|
- # ---------------------------- 解析参数 ----------------------------
|
|
|
- # 解析参数,将固定参数和任务参数合并
|
|
|
- opt = parser.parse_args_and_yaml()
|
|
|
- config = opt.__dict__
|
|
|
- # 打印参数
|
|
|
- logger.info(f"输入文件目录: {opt.input_file}")
|
|
|
-
|
|
|
+def dq_train(opt):
|
|
|
# ---------------------------- 配置计算资源和任务 ----------------------------
|
|
|
+ config = opt.__dict__
|
|
|
# 初始化资源管理器
|
|
|
rc = ResourceController(
|
|
|
max_workers=opt.system['max_workers'],
|
|
@@ -126,5 +120,33 @@ def main():
|
|
|
task_config['gpu_assignment'] = gpu_id
|
|
|
task.region_task(task_config, data_nwps)
|
|
|
|
|
|
+def dq_predict(opt):
|
|
|
+ pass
|
|
|
+
|
|
|
+def cdq_train(opt):
|
|
|
+ pass
|
|
|
+
|
|
|
+def cdq_predict(opt):
|
|
|
+ pass
|
|
|
+
|
|
|
+def main():
|
|
|
+ # ---------------------------- 解析参数 ----------------------------
|
|
|
+ # 解析参数,将固定参数和任务参数合并
|
|
|
+ opt = parser.parse_args_and_yaml()
|
|
|
+ config = opt.__dict__
|
|
|
+ # 打印参数
|
|
|
+ logger.info(f"输入文件目录: {opt.input_file}")
|
|
|
+ if 'dqyc' in opt.input_file.lower():
|
|
|
+ if 'model' in opt.input_file.lower():
|
|
|
+ dq_predict(opt)
|
|
|
+ else:
|
|
|
+ dq_predict(opt)
|
|
|
+ else:
|
|
|
+ if 'model' in opt.input_file.lower():
|
|
|
+ cdq_train(opt)
|
|
|
+ else:
|
|
|
+ cdq_predict(opt)
|
|
|
+
|
|
|
+
|
|
|
if __name__ == "__main__":
|
|
|
main()
|