David 1 ay önce
ebeveyn
işleme
329a72d4cc
1 değiştirilmiş dosya ile 30 ekleme ve 8 silme
  1. 30 8
      app/model/main.py

+ 30 - 8
app/model/main.py

@@ -47,15 +47,9 @@ def add_nwp(df_obj, df):
         df_obj[add_cols] = df_obj[add_cols].add(df, fill_value=0)
     return df_obj
 
-def main():
-    # ---------------------------- 解析参数 ----------------------------
-    # 解析参数,将固定参数和任务参数合并
-    opt = parser.parse_args_and_yaml()
-    config = opt.__dict__
-    # 打印参数
-    logger.info(f"输入文件目录: {opt.input_file}")
-
+def dq_train(opt):
     # ---------------------------- 配置计算资源和任务 ----------------------------
+    config = opt.__dict__
     # 初始化资源管理器
     rc = ResourceController(
         max_workers=opt.system['max_workers'],
@@ -126,5 +120,33 @@ def main():
     task_config['gpu_assignment'] = gpu_id
     task.region_task(task_config, data_nwps)
 
+def dq_predict(opt):
+    pass
+
+def cdq_train(opt):
+    pass
+
+def cdq_predict(opt):
+    pass
+
+def main():
+    # ---------------------------- 解析参数 ----------------------------
+    # 解析参数,将固定参数和任务参数合并
+    opt = parser.parse_args_and_yaml()
+    config = opt.__dict__
+    # 打印参数
+    logger.info(f"输入文件目录: {opt.input_file}")
+    if 'dqyc' in opt.input_file.lower():
+        if 'model' in opt.input_file.lower():
+            dq_predict(opt)
+        else:
+            dq_predict(opt)
+    else:
+        if 'model' in opt.input_file.lower():
+            cdq_train(opt)
+        else:
+            cdq_predict(opt)
+
+
 if __name__ == "__main__":
     main()