|
@@ -12,6 +12,8 @@ import os, threading
|
|
|
import argparse, yaml
|
|
|
|
|
|
|
|
|
+
|
|
|
+
|
|
|
class myargparse(argparse.ArgumentParser):
|
|
|
_save_lock = threading.Lock()
|
|
|
def __init__(self, description, add_help):
|
|
@@ -25,6 +27,14 @@ class myargparse(argparse.ArgumentParser):
|
|
|
metavar='FILE',
|
|
|
help='YAML config file specifying default arguments')
|
|
|
self.add_argument(
|
|
|
+ '-n',
|
|
|
+ '--neu_yaml',
|
|
|
+ default=
|
|
|
+ 'neu.yml',
|
|
|
+ type=str,
|
|
|
+ metavar='FILE',
|
|
|
+ help='YAML config file specifying default arguments')
|
|
|
+ self.add_argument(
|
|
|
'-i',
|
|
|
'--input_file',
|
|
|
type=str,
|
|
@@ -40,13 +50,19 @@ class myargparse(argparse.ArgumentParser):
|
|
|
def _parse_args_and_yaml(self):
|
|
|
base_parser = argparse.ArgumentParser(add_help=False)
|
|
|
base_parser.add_argument('-c', '--config_yaml', default='config.yml', type=str)
|
|
|
+ base_parser.add_argument('-n', '--neu_yaml', default='neu.yml', type=str)
|
|
|
given_configs, remaining = base_parser.parse_known_args()
|
|
|
- current_path = os.path.dirname(__file__)
|
|
|
+ current_path = os.path.dirname(os.path.dirname(__file__))
|
|
|
if given_configs.config_yaml:
|
|
|
- config_path = os.path.join(current_path, given_configs.config_yaml)
|
|
|
+ config_path = os.path.join(current_path, 'common', given_configs.config_yaml)
|
|
|
with open(config_path, 'r', encoding='utf-8') as f:
|
|
|
cfg = yaml.safe_load(f)
|
|
|
self.set_defaults(**cfg)
|
|
|
+ if given_configs.neu_yaml:
|
|
|
+ config_path = os.path.join(current_path, 'common', given_configs.neu_yaml)
|
|
|
+ with open(config_path, 'r', encoding='utf-8') as f:
|
|
|
+ model_cfg = yaml.safe_load(f)
|
|
|
+ self.set_defaults(**model_cfg)
|
|
|
# defaults will have been overridden if config file specified.
|
|
|
# opt = self.parse_args(remaining)
|
|
|
opt = self.parse_args(remaining)
|