Explorar el Código

解决参数问题

David hace 1 mes
padre
commit
75b2ef8386

+ 16 - 0
app/logs/2025-05-20/south-forecast.2025-05-20.0.log

@@ -22,3 +22,19 @@
  - _tfmw_add_deprecation_warning
 2025-05-20 09:15:34,607 - module_wrapper.py - WARNING - From E:\compete\app\model\losses.py:10: The name tf.set_random_seed is deprecated. Please use tf.compat.v1.set_random_seed instead.
  - _tfmw_add_deprecation_warning
+2025-05-20 09:20:26,339 - module_wrapper.py - WARNING - From E:\compete\app\model\losses.py:10: The name tf.set_random_seed is deprecated. Please use tf.compat.v1.set_random_seed instead.
+ - _tfmw_add_deprecation_warning
+2025-05-20 09:22:23,198 - module_wrapper.py - WARNING - From E:\compete\app\model\losses.py:10: The name tf.set_random_seed is deprecated. Please use tf.compat.v1.set_random_seed instead.
+ - _tfmw_add_deprecation_warning
+2025-05-20 09:24:15,858 - module_wrapper.py - WARNING - From E:\compete\app\model\losses.py:10: The name tf.set_random_seed is deprecated. Please use tf.compat.v1.set_random_seed instead.
+ - _tfmw_add_deprecation_warning
+2025-05-20 09:26:14,099 - module_wrapper.py - WARNING - From E:\compete\app\model\losses.py:10: The name tf.set_random_seed is deprecated. Please use tf.compat.v1.set_random_seed instead.
+ - _tfmw_add_deprecation_warning
+2025-05-20 09:37:00,765 - module_wrapper.py - WARNING - From E:\compete\app\model\losses.py:10: The name tf.set_random_seed is deprecated. Please use tf.compat.v1.set_random_seed instead.
+ - _tfmw_add_deprecation_warning
+2025-05-20 09:47:30,600 - module_wrapper.py - WARNING - From E:\compete\app\model\losses.py:10: The name tf.set_random_seed is deprecated. Please use tf.compat.v1.set_random_seed instead.
+ - _tfmw_add_deprecation_warning
+2025-05-20 09:48:51,011 - module_wrapper.py - WARNING - From E:\compete\app\model\losses.py:10: The name tf.set_random_seed is deprecated. Please use tf.compat.v1.set_random_seed instead.
+ - _tfmw_add_deprecation_warning
+2025-05-20 09:52:46,611 - module_wrapper.py - WARNING - From E:\compete\app\model\losses.py:10: The name tf.set_random_seed is deprecated. Please use tf.compat.v1.set_random_seed instead.
+ - _tfmw_add_deprecation_warning

+ 18 - 3
app/model/config.py

@@ -24,16 +24,31 @@ class myargparse(argparse.ArgumentParser):
             type=str,
             metavar='FILE',
             help='YAML config file specifying default arguments')
-
+        self.add_argument(
+            '-i',
+            '--input_file',
+            type=str,
+            metavar='FILE',
+            help='训练预测数据路径')
+        self.add_argument(
+            '-m',
+            '--model_name',
+            type=str,
+            metavar='FILE',
+            help='模型选择')
 
     def _parse_args_and_yaml(self):
-        given_configs, remaining = self.parse_known_args()
+        base_parser = argparse.ArgumentParser(add_help=False)
+        base_parser.add_argument('-c', '--config_yaml', default='config.yml', type=str)
+        given_configs, remaining = base_parser.parse_known_args()
         current_path = os.path.dirname(__file__)
         if given_configs.config_yaml:
-            with open(current_path + '/' + given_configs.config_yaml, 'r', encoding='utf-8') as f:
+            config_path = os.path.join(current_path, given_configs.config_yaml)
+            with open(config_path, 'r', encoding='utf-8') as f:
                 cfg = yaml.safe_load(f)
                 self.set_defaults(**cfg)
         # defaults will have been overridden if config file specified.
+        # opt = self.parse_args(remaining)
         opt = self.parse_args(remaining)
         # Cache the args as a text string to save them in the output dir later
         opt_text = yaml.safe_dump(opt.__dict__, default_flow_style=False)

+ 4 - 1
app/model/config.yml

@@ -44,4 +44,7 @@ Model:
   time_step: 16
   train_data_fill: true
   use_cuda: false
-  valid_data_rate: 0.15
+  valid_data_rate: 0.15
+
+input_file: E:/compete/app/model/data/DQYC/qy/62/1002/2025-04-21/IN
+model_name: lstm

+ 1 - 4
app/model/main.py

@@ -17,7 +17,7 @@ from concurrent.futures import ProcessPoolExecutor
 from app.common.logs import params, logger
 from app.model.config import myargparse
 from app.model.resource_manager import ResourceController
-from app.model.task_worker import station_task
+# from app.model.task_worker import station_task
 """"
 调用思路
    xxxx 1. 从入口参数中获取IN OUT文件位置 xxxx
@@ -37,9 +37,6 @@ def main():
     # ---------------------------- 解析参数 ----------------------------
     # 创建解析器对象
     parser = myargparse(description='算法', add_help=False)
-    # 添加参数
-    parser.add_argument("input_file", help="输入文件路径")    # 第一个位置参数
-    parser.add_argument("--model_name", default="lstm", help='选择短期模型')    # 第二个位置参数
     # 解析参数
     opt = parser.parse_args_and_yaml()
     config = opt.__dict__