David, 2 months ago
parent
commit
a1b57dc51b

+ 2 - 2
app/common/data_handler.py

@@ -13,9 +13,9 @@ from sklearn.preprocessing import MinMaxScaler
 from app.common.data_cleaning import *
 
 class DataHandler(object):
-    def __init__(self, logger, args):
+    def __init__(self, logger, params):
         self.logger = logger
-        self.opt = argparse.Namespace(**args)
+        self.opt = argparse.Namespace(**params)
 
     def get_train_data(self, dfs, col_time, target):
         train_x, valid_x, train_y, valid_y = [], [], [], []

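Side note on the pattern being renamed throughout this commit: `params` is the config dict loaded from neu.yaml (see app/common/logs.py below), and `argparse.Namespace(**params)` merely wraps it for attribute-style access. A minimal sketch, with hypothetical key values:

```python
import argparse

# Hypothetical config values; the real keys come from neu.yaml.
params = {"col_time": "Time", "switch_nwp_owner": False}
opt = argparse.Namespace(**params)
print(opt.col_time)   # "Time" — attribute access instead of params["col_time"]
```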
+ 2 - 3
app/common/limited_power_solar.py

@@ -10,10 +10,9 @@ current_path = os.path.dirname(__file__)
 parent_path = os.path.dirname(current_path)
 
 class LimitPower(object):
-    def __init__(self, logger, args, weather_power):
+    def __init__(self, logger, params, weather_power):
         self.logger = logger
-        self.args = args
-        self.opt = argparse.Namespace(**args)
+        self.opt = argparse.Namespace(**params)
         self.weather_power = weather_power
 
     def segment_statis(self):

+ 2 - 2
app/common/limited_power_wind.py

@@ -9,9 +9,9 @@ current_path = os.path.dirname(__file__)
 parent_path = os.path.dirname(current_path)
 
 class LimitPower(object):
-    def __init__(self, logger, args, weather_power):
+    def __init__(self, logger, params, weather_power):
         self.logger = logger
-        self.opt = argparse.Namespace(**args)
+        self.opt = argparse.Namespace(**params)
         self.weather_power = weather_power
         self.step = self.opt.usable_power_w['step']
         self.segs = np.array([x * self.step for x in range(1, 50)])  # segment wind speed into 50 intervals of width step

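For context, `segs` above is a ladder of wind-speed bin edges spaced `step` apart. The downstream use is not shown in this diff, so the bucketing below with `np.digitize` and the `step` value are only assumptions:

```python
import numpy as np

step = 0.5                                          # assumed value of usable_power_w['step']
segs = np.array([x * step for x in range(1, 50)])   # edges: step, 2*step, ..., 49*step
wind_speed = np.array([0.3, 3.7, 12.2])
print(np.digitize(wind_speed, segs))                # segment index for each speed
```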
+ 1 - 1
app/common/logs.py

@@ -16,7 +16,7 @@ import logging, logging.handlers, time, os, re
 from logging.handlers import BaseRotatingHandler
 current_dir = os.path.dirname(os.path.abspath(__file__))
 with open(os.path.join(current_dir, 'neu.yaml'), 'r', encoding='utf-8') as f:
-    args = yaml.safe_load(f)
+    params = yaml.safe_load(f)
 
 class DailyRotatingFileHandler(BaseRotatingHandler):
     """

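Worth noting: because logs.py builds `params` at import time, every `from app.common.logs import params` hands back the same dict object, so the in-place updates done by the training/prediction scripts further down are visible to all importers. A self-contained stand-in sketch of that sharing:

```python
params = {"model_table": "models_"}   # stand-in for the dict logs.py loads from neu.yaml

def train(cfg):
    cfg["model_table"] += "F001"      # hypothetical farm id, mirroring params['model_table'] += farm_id

train(params)
print(params["model_table"])          # "models_F001": the shared dict was mutated in place
```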
+ 2 - 2
app/common/tf_cnn.py

@@ -19,9 +19,9 @@ model_lock = Lock()
 
 
 class CNNHandler(object):
-    def __init__(self, logger, args):
+    def __init__(self, logger, params):
         self.logger = logger
-        self.opt = argparse.Namespace(**args)
+        self.opt = argparse.Namespace(**params)
         self.model = None
         self.model_params = None
         self.mongoUtils = MongoUtils(logger)

+ 2 - 2
app/common/tf_fmi.py

@@ -20,9 +20,9 @@ model_lock = Lock()
 
 
 class FMIHandler(object):
-    def __init__(self, logger, args):
+    def __init__(self, logger, params):
         self.logger = logger
-        self.opt = argparse.Namespace(**args)
+        self.opt = argparse.Namespace(**params)
         self.model = None
         self.model_params = None
         self.mongoUtils = MongoUtils(logger)

+ 2 - 2
app/common/tf_lstm.py

@@ -19,9 +19,9 @@ model_lock = Lock()
 
 
 class TSHandler(object):
-    def __init__(self, logger, args):
+    def __init__(self, logger, params):
         self.logger = logger
-        self.opt = argparse.Namespace(**args)
+        self.opt = argparse.Namespace(**params)
         self.model = None
         self.model_params = None
         self.mongoUtils = MongoUtils(logger)

+ 13 - 12
app/model/main.py

@@ -11,7 +11,7 @@
 import argparse
 import pandas as pd
 from pathlib import Path
-from app.common.logs import args, logger
+from app.common.logs import params, logger
 
 """"
 调用思路
@@ -41,7 +41,7 @@ def material(input_file, isDq=True):
         nwp_v_h = pd.read_csv(input_file / '0' / nwp_v_h, sep=r'\s+', header=0)
         nwp_own = pd.read_csv(input_file / '1' / nwp_own, sep=r'\s+', header=0)
         nwp_own_h = pd.read_csv(input_file / '1' / nwp_own_h, sep=r'\s+', header=0)
-        if args['switch_nwp_owner']:
+        if params['switch_nwp_owner']:
             nwp_v, nwp_v_h = nwp_own, nwp_own_h
         # if wind power
         if plant_type == 0:
@@ -67,32 +67,32 @@ def material(input_file, isDq=True):
         return basic_area
 
 def clean_power(power, env, plant_id):
-    env_power = pd.merge(env, power, on=args['col_time'])
+    env_power = pd.merge(env, power, on=params['col_time'])
     if 'HubSpeed' in env.columns.tolist():
         from app.common.limited_power_wind import LimitPower
-        lp = LimitPower(logger, args, env_power)
+        lp = LimitPower(logger, params, env_power)
         power = lp.clean_limited_power(plant_id, True)
     elif 'Irradiance' in env.columns.tolist():
         from app.common.limited_power_solar import LimitPower
-        lp = LimitPower(logger, args, env_power)
+        lp = LimitPower(logger, params, env_power)
         power = lp.clean_limited_power(plant_id, True)
     return power
 
 
-def input_file_handler(input_file: str):
+def input_file_handler(input_file: str, model_name: str):
     # DQYC: short-term forecast; qy: regional level
     if 'dqyc' in input_file.lower():
         station_info, station_info_d, nwp, nwp_h, power, nwp_v, nwp_v_h, env = material(input_file, True)
-        cap = round(station_info['PlantCap'][0], 2)
+        cap = round(float(station_info['PlantCap'][0]), 2)
         plant_id = int(station_info['PlantID'][0])
         # path contains "model": training
         if 'model' in input_file.lower():
             if env is not None:     # clean curtailed-power records
                 power = clean_power(power, env, plant_id)
-            train_data = pd.merge(nwp_v_h, power, on=args['col_time'])
-            if args['model_name'] == 'fmi':
+            train_data = pd.merge(nwp_v_h, power, on=params['col_time'])
+            if model_name == 'fmi':
                 from app.model.tf_fmi_train import model_training
-            elif args['model_name'] == 'cnn':
+            elif model_name == 'cnn':
                 from app.model.tf_cnn_train import model_training
             else:
                 from app.model.tf_lstm_train import model_training
@@ -109,14 +109,15 @@ def main():
     parser = argparse.ArgumentParser(description="Program description")
     # create
     # add arguments
-    parser.add_argument("input_file", help="input file path")
+    parser.add_argument("input_file", help="input file path")    # first positional argument
 
+    parser.add_argument("model_name", nargs='?', default="cnn", help='short-term model to use')    # second positional argument
     # parse arguments
     args = parser.parse_args()
 
     # use the arguments
     print(f"File: {args.input_file}")
-    input_file_handler(args.input_file)
+    input_file_handler(args.input_file, args.model_name)
 
 
 if __name__ == "__main__":

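One wrinkle with the new CLI argument: argparse ignores `default` on a positional argument unless `nargs='?'` makes it optional, which is why it is added above; as written in the commit, `model_name` would be required and the `"cnn"` default would never apply. A quick self-contained check:

```python
import argparse

parser = argparse.ArgumentParser(description="demo")
parser.add_argument("input_file", help="input file path")
# Without nargs='?' this positional would be required and default="cnn" would never apply.
parser.add_argument("model_name", nargs='?', default="cnn", help="fmi / cnn / lstm")
print(parser.parse_args(["/data/in"]).model_name)   # -> "cnn"
```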
+ 10 - 11
app/model/tf_cnn_train.py

@@ -8,8 +8,7 @@ import json, os
 import numpy as np
 import traceback
 import logging
-
-from app.common.logs import args
+from app.common.logs import params
 from app.common.data_handler import DataHandler, write_number_to_file
 import time
 from app.common.tf_cnn import CNNHandler
@@ -18,8 +17,8 @@ from app.common.logs import logger
 np.random.seed(42)  # NumPy random seed
 # tf.set_random_seed(42)  # TensorFlow random seed
 
-dh = DataHandler(logger, args)
-cnn = CNNHandler(logger, args)
+dh = DataHandler(logger, params)
+cnn = CNNHandler(logger, params)
 mgUtils = MongoUtils(logger)
 
 def model_training(train_data, input_file, cap):
@@ -58,16 +57,16 @@ def model_training(train_data, input_file, cap):
         # update algorithm status: 1. started successfully
         write_number_to_file(os.path.join(output_file, status_file), 1, 1, 'rewrite')
         # ------------ assemble model data ------------
-        args['Model']['features'] = ','.join(dh.opt.features)
-        args.update({
-            'params': json.dumps(args),
+        params['Model']['features'] = ','.join(dh.opt.features)
+        params.update({
+            'params': json.dumps(params),
             'descr': f'南网竞赛-{farm_id}',
             'gen_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
-            'model_table': args['model_table'] + farm_id,
-            'scaler_table': args['scaler_table'] + farm_id
+            'model_table': params['model_table'] + farm_id,
+            'scaler_table': params['scaler_table'] + farm_id
         })
-        mgUtils.insert_trained_model_into_mongo(ts_model, args)
-        mgUtils.insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, args)
+        mgUtils.insert_trained_model_into_mongo(ts_model, params)
+        mgUtils.insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, params)
         # update algorithm status: finished normally
         write_number_to_file(os.path.join(output_file, status_file), 2, 2)
     except Exception as e:

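Note the evaluation order in the block above (the same pattern repeats in tf_fmi_train.py and tf_lstm_train.py below): the dict literal passed to `params.update(...)` is built first, so `json.dumps(params)` snapshots the config *before* the bookkeeping keys (`descr`, `gen_time`, per-farm table names) are merged in. A reduced sketch with made-up values:

```python
import json, time

params = {"Model": {"features": "a,b"}, "model_table": "models_"}
params.update({
    "params": json.dumps(params),    # snapshot taken before the keys below exist
    "gen_time": time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
    "model_table": params["model_table"] + "F001",   # hypothetical farm id
})
print(json.loads(params["params"]))   # only Model/model_table — no gen_time
```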
+ 10 - 10
app/model/tf_fmi_train.py

@@ -9,7 +9,7 @@ import numpy as np
 import traceback
 import logging
 
-from app.common.logs import args
+from app.common.logs import params
 from app.common.data_handler import DataHandler, write_number_to_file
 import time
 from app.common.tf_fmi import FMIHandler
@@ -18,8 +18,8 @@ from app.common.logs import logger
 np.random.seed(42)  # NumPy random seed
 # tf.set_random_seed(42)  # TensorFlow random seed
 
-dh = DataHandler(logger, args)
-fmi = FMIHandler(logger, args)
+dh = DataHandler(logger, params)
+fmi = FMIHandler(logger, params)
 mgUtils = MongoUtils(logger)
 
 def model_training(train_data, input_file, cap):
@@ -58,16 +58,16 @@ def model_training(train_data, input_file, cap):
         # update algorithm status: 1. started successfully
         write_number_to_file(os.path.join(output_file, status_file), 1, 1, 'rewrite')
         # ------------ assemble model data ------------
-        args['Model']['features'] = ','.join(dh.opt.features)
-        args.update({
-            'params': json.dumps(args),
+        params['Model']['features'] = ','.join(dh.opt.features)
+        params.update({
+            'params': json.dumps(params),
             'descr': f'南网竞赛-{farm_id}',
             'gen_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
-            'model_table': args['model_table'] + farm_id,
-            'scaler_table': args['scaler_table'] + farm_id
+            'model_table': params['model_table'] + farm_id,
+            'scaler_table': params['scaler_table'] + farm_id
         })
-        mgUtils.insert_trained_model_into_mongo(ts_model, args)
-        mgUtils.insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, args)
+        mgUtils.insert_trained_model_into_mongo(ts_model, params)
+        mgUtils.insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, params)
         # update algorithm status: finished normally
         write_number_to_file(os.path.join(output_file, status_file), 2, 2)
     except Exception as e:

+ 10 - 10
app/model/tf_lstm_train.py

@@ -9,7 +9,7 @@ import numpy as np
 import traceback
 import logging
 
-from app.common.logs import args
+from app.common.logs import params
 from app.common.data_handler import DataHandler, write_number_to_file
 import time
 from app.common.tf_lstm import TSHandler
@@ -18,8 +18,8 @@ from app.common.logs import logger
 np.random.seed(42)  # NumPy random seed
 # tf.set_random_seed(42)  # TensorFlow random seed
 
-dh = DataHandler(logger, args)
-ts = TSHandler(logger, args)
+dh = DataHandler(logger, params)
+ts = TSHandler(logger, params)
 mgUtils = MongoUtils(logger)
 
 def model_training(train_data, input_file, cap):
@@ -58,16 +58,16 @@ def model_training(train_data, input_file, cap):
         # update algorithm status: 1. started successfully
         write_number_to_file(os.path.join(output_file, status_file), 1, 1, 'rewrite')
         # ------------ assemble model data ------------
-        args['Model']['features'] = ','.join(dh.opt.features)
-        args.update({
-            'params': json.dumps(args),
+        params['Model']['features'] = ','.join(dh.opt.features)
+        params.update({
+            'params': json.dumps(params),
             'descr': f'南网竞赛-{farm_id}',
             'gen_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
-            'model_table': args['model_table'] + farm_id,
-            'scaler_table': args['scaler_table'] + farm_id
+            'model_table': params['model_table'] + farm_id,
+            'scaler_table': params['scaler_table'] + farm_id
         })
-        mgUtils.insert_trained_model_into_mongo(ts_model, args)
-        mgUtils.insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, args)
+        mgUtils.insert_trained_model_into_mongo(ts_model, params)
+        mgUtils.insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, params)
         # update algorithm status: finished normally
         write_number_to_file(os.path.join(output_file, status_file), 2, 2)
     except Exception as e:

+ 24 - 24
app/predict/main.py

@@ -11,7 +11,7 @@
 import argparse
 import pandas as pd
 from pathlib import Path
-from app.common.logs import logger, args
+from app.common.logs import logger, params
 
 """
 Call flow
@@ -31,46 +31,46 @@ def material(input_file, isDq=True):
     nwp_v, nwp_v_h = 'DQYC_IN_FORECAST_WEATHER.txt', 'DQYC_IN_FORECAST_WEATHER_H.txt'
     nwp_own, nwp_own_h = 'DQYC_IN_FORECAST_WEATHER_OWNER.txt', 'DQYC_IN_FORECAST_WEATHER_OWNER_H.txt',
     input_file = Path(input_file)
-    basic = pd.read_csv(input_file / basi, sep='\s+', header=0)
-    power = pd.read_csv(input_file / power, sep='\s+', header=0)
+    basic = pd.read_csv(input_file / basi, sep=r'\s+', header=0)
+    power = pd.read_csv(input_file / power, sep=r'\s+', header=0)
     plant_type = int(basic.loc[basic['PropertyID'].to_list().index('PlantType'), 'Value'])
     if isDq:
-        nwp_v = pd.read_csv(input_file / '0' / nwp_v, sep='\s+', header=0)
-        nwp_v_h = pd.read_csv(input_file / '0' / nwp_v_h, sep='\s+', header=0)
-        nwp_own = pd.read_csv(input_file / '1' / nwp_own, sep='\s+', header=0)
-        nwp_own_h = pd.read_csv(input_file / '1' / nwp_own_h, sep='\s+', header=0)
-        if args['switch_nwp_owner']:
+        nwp_v = pd.read_csv(input_file / '0' / nwp_v, sep=r'\s+', header=0)
+        nwp_v_h = pd.read_csv(input_file / '0' / nwp_v_h, sep=r'\s+', header=0)
+        nwp_own = pd.read_csv(input_file / '1' / nwp_own, sep=r'\s+', header=0)
+        nwp_own_h = pd.read_csv(input_file / '1' / nwp_own_h, sep=r'\s+', header=0)
+        if params['switch_nwp_owner']:
             nwp_v, nwp_v_h = nwp_own, nwp_own_h
         # if wind power
         if plant_type == 0:
-            station_info = pd.read_csv(input_file / station_info_w, sep='\s+', header=0)
-            station_info_d = pd.read_csv(input_file / station_info_d_w, sep='\s+', header=0)
-            nwp = pd.read_csv(input_file / nwp_w, sep='\s+', header=0)
-            nwp_h = pd.read_csv(input_file / nwp_w_h, sep='\s+', header=0)
+            station_info = pd.read_csv(input_file / station_info_w, sep=r'\s+', header=0)
+            station_info_d = pd.read_csv(input_file / station_info_d_w, sep=r'\s+', header=0)
+            nwp = pd.read_csv(input_file / nwp_w, sep=r'\s+', header=0)
+            nwp_h = pd.read_csv(input_file / nwp_w_h, sep=r'\s+', header=0)
             return station_info, station_info_d, nwp, nwp_h, power, nwp_v, nwp_v_h
         # if solar (PV)
         elif plant_type == 1:
-            station_info = pd.read_csv(input_file / station_info_s, sep='\s+', header=0)
-            station_info_d = pd.read_csv(input_file / station_info_d_s, sep='\s+', header=0)
-            nwp = pd.read_csv(input_file / nwp_s, sep='\s+', header=0)
-            nwp_h = pd.read_csv(input_file / nwp_s_h, sep='\s+', header=0)
+            station_info = pd.read_csv(input_file / station_info_s, sep=r'\s+', header=0)
+            station_info_d = pd.read_csv(input_file / station_info_d_s, sep=r'\s+', header=0)
+            nwp = pd.read_csv(input_file / nwp_s, sep=r'\s+', header=0)
+            nwp_h = pd.read_csv(input_file / nwp_s_h, sep=r'\s+', header=0)
             return station_info, station_info_d, nwp, nwp_h, power, nwp_v, nwp_v_h
     else:
         # regional-level forecasting is TBD; may need to iterate over stations to collect data
-        basic_area = pd.read_csv(input_file / basi_area, sep='\s+', header=0)
+        basic_area = pd.read_csv(input_file / basi_area, sep=r'\s+', header=0)
         return basic_area
 
-def input_file_handler(input_file: str):
+def input_file_handler(input_file: str, model_name: str):
     # DQYC: short-term forecast; qy: regional level
     if 'dqyc' in input_file.lower():
         station_info, station_info_d, nwp, nwp_h, power, nwp_v, nwp_v_h = material(input_file, True)
-        cap = round(station_info['PlantCap'][0], 2)
+        cap = round(float(station_info['PlantCap'][0]), 2)
         # path contains "predict": prediction
         if 'predict' in input_file.lower():
             pre_data = nwp_v
-            if args['model_name'] == 'fmi':
+            if model_name == 'fmi':
                 from app.predict.tf_fmi_pre import model_prediction
-            elif args['model_name'] == 'cnn':
+            elif model_name == 'cnn':
                 from app.predict.tf_cnn_pre import model_prediction
             else:
                 from app.predict.tf_lstm_pre import model_prediction
@@ -89,14 +89,14 @@ def main():
     parser = argparse.ArgumentParser(description="Program description")
     # create
     # add arguments
-    parser.add_argument("input_file", help="input file path")
-
+    parser.add_argument("input_file", help="input file path")    # first positional argument
+    parser.add_argument("model_name", nargs='?', default="cnn", help='short-term model to use')  # second positional argument
     # parse arguments
     args = parser.parse_args()
 
     # use the arguments
     print(f"File: {args.input_file}")
-    input_file_handler(args.input_file)
+    input_file_handler(args.input_file, args.model_name)
 
 
 if __name__ == "__main__":

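The prediction scripts below now clamp output to the plant capacity passed down from `station_info` (`cap`) instead of a config lookup (`args['cap']`). A reduced pandas sketch of that post-processing, with made-up numbers:

```python
import pandas as pd

cap = 96.6                                             # plant capacity parsed from station_info
pre_data = pd.DataFrame({"Power": [-0.4, 12.345, 120.0]})
pre_data.loc[:, "Power"] = pre_data["Power"].round(2)
pre_data.loc[pre_data["Power"] > cap, "Power"] = cap   # clamp high side to capacity
pre_data.loc[pre_data["Power"] < 0, "Power"] = 0       # no negative power
print(pre_data)   # equivalently: pre_data["Power"].round(2).clip(0, cap)
```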
+ 11 - 11
app/predict/tf_cnn_pre.py

@@ -14,15 +14,15 @@ import time, json
 
 model_lock = Lock()
 from itertools import chain
-from app.common.logs import logger, args
+from app.common.logs import logger, params
 from app.common.tf_cnn import CNNHandler
 from app.common.dbmg import MongoUtils
 
 np.random.seed(42)  # NumPy random seed
 
 
-dh = DataHandler(logger, args)
-cnn = CNNHandler(logger, args)
+dh = DataHandler(logger, params)
+cnn = CNNHandler(logger, params)
 mgUtils = MongoUtils(logger)
 
 
@@ -37,11 +37,11 @@ def model_prediction(pre_data, input_file, cap):
     file = 'DQYC_OUT_PREDICT_POWER.txt'
     status_file = 'STATUS.TXT'
     try:
-        args['model_table'] += farm_id
-        args['scaler_table'] += farm_id
-        feature_scaler, target_scaler = mgUtils.get_scaler_model_from_mongo(args)
-        cnn.opt.cap = round(target_scaler.transform(np.array([[float(cap)]]))[0, 0], 2)
-        cnn.get_model(args)
+        params['model_table'] += farm_id
+        params['scaler_table'] += farm_id
+        feature_scaler, target_scaler = mgUtils.get_scaler_model_from_mongo(params)
+        cnn.opt.cap = round(target_scaler.transform(np.array([[cap]]))[0, 0], 2)
+        cnn.get_model(params)
         dh.opt.features = json.loads(cnn.model_params).get('Model').get('features', ','.join(cnn.opt.features)).split(',')
         scaled_pre_x, pre_data = dh.pre_data_handler(pre_data, feature_scaler)
 
@@ -52,10 +52,10 @@ def model_prediction(pre_data, input_file, cap):
         res = list(chain.from_iterable(target_scaler.inverse_transform([cnn.predict(scaled_pre_x).flatten()])))
         pre_data['Power'] = res[:len(pre_data)]
         pre_data['PlantID'] = farm_id
-        pre_data = pre_data[['PlantID', args['col_time'], 'Power']]
+        pre_data = pre_data[['PlantID', params['col_time'], 'Power']]
 
         pre_data.loc[:, 'Power'] = pre_data['Power'].round(2)
-        pre_data.loc[pre_data['Power'] > args['cap'], 'Power'] = args['cap']
+        pre_data.loc[pre_data['Power'] > cap, 'Power'] = cap
         pre_data.loc[pre_data['Power'] < 0, 'Power'] = 0
         pre_data.to_csv(os.path.join(output_file, file), sep=' ', index=False)
         # update algorithm status: finished normally
@@ -72,7 +72,7 @@ def model_prediction(pre_data, input_file, cap):
     end_time = time.time()
 
     result['success'] = success
-    result['args'] = args
+    result['args'] = params
     result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
     result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
     print("Program execution ends!")

+ 10 - 10
app/predict/tf_fmi_pre.py

@@ -14,15 +14,15 @@ import time, json
 
 model_lock = Lock()
 from itertools import chain
-from app.common.logs import logger, args
+from app.common.logs import logger, params
 from app.common.tf_fmi import FMIHandler
 from app.common.dbmg import MongoUtils
 
 np.random.seed(42)  # NumPy random seed
 
 
-dh = DataHandler(logger, args)
-fmi = FMIHandler(logger, args)
+dh = DataHandler(logger, params)
+fmi = FMIHandler(logger, params)
 mgUtils = MongoUtils(logger)
 
 
@@ -37,11 +37,11 @@ def model_prediction(pre_data, input_file, cap):
     file = 'DQYC_OUT_PREDICT_POWER.txt'
     status_file = 'STATUS.TXT'
     try:
-        args['model_table'] += farm_id
-        args['scaler_table'] += farm_id
-        feature_scaler, target_scaler = mgUtils.get_scaler_model_from_mongo(args)
-        fmi.opt.cap = round(target_scaler.transform(np.array([[float(cap)]]))[0, 0], 2)
-        fmi.get_model(args)
+        params['model_table'] += farm_id
+        params['scaler_table'] += farm_id
+        feature_scaler, target_scaler = mgUtils.get_scaler_model_from_mongo(params)
+        fmi.opt.cap = round(target_scaler.transform(np.array([[cap]]))[0, 0], 2)
+        fmi.get_model(params)
         dh.opt.features = json.loads(fmi.model_params).get('Model').get('features', ','.join(fmi.opt.features)).split(',')
         scaled_pre_x, pre_data = dh.pre_data_handler(pre_data, feature_scaler)
 
@@ -52,10 +52,10 @@ def model_prediction(pre_data, input_file, cap):
         res = list(chain.from_iterable(target_scaler.inverse_transform([fmi.predict(scaled_pre_x).flatten()])))
         pre_data['Power'] = res[:len(pre_data)]
         pre_data['PlantID'] = farm_id
-        pre_data = pre_data[['PlantID', args['col_time'], 'Power']]
+        pre_data = pre_data[['PlantID', params['col_time'], 'Power']]
 
         pre_data.loc[:, 'Power'] = pre_data['Power'].round(2)
-        pre_data.loc[pre_data['Power'] > args['cap'], 'Power'] = args['cap']
+        pre_data.loc[pre_data['Power'] > cap, 'Power'] = cap
         pre_data.loc[pre_data['Power'] < 0, 'Power'] = 0
         pre_data.to_csv(os.path.join(output_file, file), sep=' ', index=False)
         # update algorithm status: finished normally

+ 10 - 10
app/predict/tf_lstm_pre.py

@@ -14,15 +14,15 @@ import time, json
 
 model_lock = Lock()
 from itertools import chain
-from app.common.logs import logger, args
+from app.common.logs import logger, params
 from app.common.tf_lstm import TSHandler
 from app.common.dbmg import MongoUtils
 
 np.random.seed(42)  # NumPy random seed
 
 
-dh = DataHandler(logger, args)
-ts = TSHandler(logger, args)
+dh = DataHandler(logger, params)
+ts = TSHandler(logger, params)
 mgUtils = MongoUtils(logger)
 
 
@@ -37,11 +37,11 @@ def model_prediction(pre_data, input_file, cap):
     file = 'DQYC_OUT_PREDICT_POWER.txt'
     status_file = 'STATUS.TXT'
     try:
-        args['model_table'] += farm_id
-        args['scaler_table'] += farm_id
-        feature_scaler, target_scaler = mgUtils.get_scaler_model_from_mongo(args)
-        ts.opt.cap = round(target_scaler.transform(np.array([[float(cap)]]))[0, 0], 2)
-        ts.get_model(args)
+        params['model_table'] += farm_id
+        params['scaler_table'] += farm_id
+        feature_scaler, target_scaler = mgUtils.get_scaler_model_from_mongo(params)
+        ts.opt.cap = round(target_scaler.transform(np.array([[cap]]))[0, 0], 2)
+        ts.get_model(params)
         dh.opt.features = json.loads(ts.model_params).get('Model').get('features', ','.join(ts.opt.features)).split(',')
         scaled_pre_x, pre_data = dh.pre_data_handler(pre_data, feature_scaler)
 
@@ -52,10 +52,10 @@ def model_prediction(pre_data, input_file, cap):
         res = list(chain.from_iterable(target_scaler.inverse_transform([ts.predict(scaled_pre_x).flatten()])))
         pre_data['Power'] = res[:len(pre_data)]
         pre_data['PlantID'] = farm_id
-        pre_data = pre_data[['PlantID', args['col_time'], 'Power']]
+        pre_data = pre_data[['PlantID', params['col_time'], 'Power']]
 
         pre_data.loc[:, 'Power'] = pre_data['Power'].round(2)
-        pre_data.loc[pre_data['Power'] > args['cap'], 'Power'] = args['cap']
+        pre_data.loc[pre_data['Power'] > cap, 'Power'] = cap
         pre_data.loc[pre_data['Power'] < 0, 'Power'] = 0
         pre_data.to_csv(os.path.join(output_file, file), sep=' ', index=False)
         # update algorithm status: finished normally
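Finally, on `target_scaler.transform(np.array([[cap]]))` in the hunks above: the capacity is pushed through the target's MinMaxScaler so the handlers work with it in scaled units (and `float(cap)` could be dropped because main.py now casts the capacity before passing it down). A self-contained sketch with an invented fit range:

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

target_scaler = MinMaxScaler().fit(np.array([[0.0], [96.6]]))  # invented historical power range
cap = 96.6
scaled_cap = round(target_scaler.transform(np.array([[cap]]))[0, 0], 2)
print(scaled_cap)   # 1.0 — capacity expressed on the scaler's 0-1 scale
```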