#!/usr/bin/env python
# -*- coding: utf-8 -*-
# time: 2024/10/23 13:04
# file: forest.py
# author: David
# company: shenyang JY
from pathlib import Path
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.ensemble import RandomForestRegressor
import joblib
import threading


class Forest(object):
    """Error-correction models for the forecast column ``C_FP_VALUE``.

    Four regressors (``gbdt1`` .. ``gbdt4``) are shared at class level and are
    loaded from ``./var/lgb_model_{1..4}.pkl`` by ``set_errors``. Each model
    predicts an error term that is added to ``C_FP_VALUE``. The corrected value
    is then clamped to ``[0, opt.cap]``.

    The ``gbdt``/``lgb`` names are kept for compatibility, even though
    ``training_model`` produces ``RandomForestRegressor`` models.
    """

    # Held while set_errors replaces the class-level model slots.
    lock = threading.Lock()
    gbdt1 = None  # model used when hour == 1
    gbdt2 = None  # model used when hour == 2
    gbdt3 = None  # model used when hour == 3
    gbdt4 = None  # model used for any other hour value

    def __init__(self, log, args):
        """Parse the configuration, build history feature names and load models.

        :param log: logger object; its ``info`` method is used for messages.
        :param args: config object whose ``parse_args_and_yaml()`` returns the
            options object stored as ``self.opt``.
        """
        self.args = args
        self.logger = log
        self.opt = self.args.parse_args_and_yaml()
        # Lagged feature names "<col>1" .. "<col>N" for C_REAL_VALUE and for the
        # configured env column, with N = history_env_hours * 4.
        # NOTE(review): the factor 4 presumably means 15-minute steps -- confirm.
        self.hs = []
        for fea in ['C_REAL_VALUE'] + [self.opt.usable_power["env"]]:
            self.hs += [fea + str(i + 1) for i in range(self.opt.history_env_hours * 4)]
        self.set_errors()

    def _feature_columns(self, opt):
        """Return model input columns: configured data columns, history features, forecast."""
        return opt.data['columns'] + self.hs + ['C_FP_VALUE']

    def training_model(self, opt):
        """Train and save one error model per training file.

        For every ``lgb_data*.csv`` found recursively under ``./cache/data/``:

        - tune a ``RandomForestRegressor`` with 5-fold ``GridSearchCV``,
          scored by negative MSE;
        - predict the ``error`` column from the feature columns;
        - write the best estimator to ``./var/lgb_model_<k>.pkl``, where ``<k>``
          is the last character of the file's stem (``lgb_data1.csv`` -> ``1``).

        The class-level models are not reloaded here. Call ``set_errors`` to
        pick up the newly written files.

        :param opt: options object providing ``data['columns']``.
        """
        cols = self._feature_columns(opt)
        # Sorted so that files are processed in a deterministic order.
        files = sorted(Path('./cache/data/').rglob("lgb_data*.csv"))
        # Fixed hyper-parameter search space.
        params_grid = {
            'n_estimators': [50, 100, 200, 500],
            'max_features': [None, 'sqrt', 'log2'],
            'max_depth': [4, 6, 8, 10],
        }
        for file in files:
            data = pd.read_csv(file)
            df_train_x, df_train_y = data[cols].values, data['error'].values
            rf = RandomForestRegressor(random_state=42)
            grid_search = GridSearchCV(estimator=rf, param_grid=params_grid, cv=5,
                                       n_jobs=None, verbose=2,
                                       scoring='neg_mean_squared_error')
            grid_search.fit(df_train_x, df_train_y)
            # Report the best parameters and the best score.
            self.logger.info(f"Best parameters found: {grid_search.best_params_}")
            # best_score_ is the negated MSE; flip the sign to report the MSE.
            self.logger.info(f"Best cross-validation score: {-grid_search.best_score_}")
            best_model = grid_search.best_estimator_
            # Save the model.
            self.logger.info('保存模型...')
            # joblib.dump does not create missing parent directories.
            Path('./var').mkdir(parents=True, exist_ok=True)
            joblib.dump(best_model, './var/lgb_model_{}.pkl'.format(file.stem[-1]))

    @classmethod
    def set_errors(cls):
        """(Re)load the four error models from ``./var`` into the class-level slots.

        Loading is best-effort. On the first failure the exception args are
        printed and no exception propagates. Slots loaded before the failure
        hold the new models; the remaining slots keep their previous value
        (``None`` if never loaded).
        """
        try:
            with cls.lock:
                cls.gbdt1 = joblib.load('./var/lgb_model_1.pkl')
                cls.gbdt2 = joblib.load('./var/lgb_model_2.pkl')
                cls.gbdt3 = joblib.load('./var/lgb_model_3.pkl')
                cls.gbdt4 = joblib.load('./var/lgb_model_4.pkl')
        except Exception as e:
            print("加载模型权重失败:{}".format(e.args))

    def predict_error_clock(self, hour, data, api=False):
        """Correct the forecast ``C_FP_VALUE`` with the model selected by ``hour``.

        :param hour: 1, 2 or 3 select ``gbdt1``..``gbdt3``; any other value
            selects ``gbdt4``.
        :param data: with ``api=True``, a single row indexable by column name
            (e.g. a ``pd.Series`` from ``iterrows``). Otherwise a ``pd.DataFrame``.
            Either form must contain the feature columns and ``C_FP_VALUE``.
        :param api: ``True`` for single-row (scalar) mode, ``False`` for batch
            (array) mode.
        :return: forecast + predicted error, clamped to ``[0, opt.cap]`` and
            rounded to 2 decimals. Returns a scalar in api mode and an
            ``np.ndarray`` otherwise.
        """
        cols = self._feature_columns(self.opt)
        if hour == 1:
            gbdt = Forest.gbdt1
        elif hour == 2:
            gbdt = Forest.gbdt2
            self.logger.info("预测模型,地址:{}".format(id(gbdt)))
        elif hour == 3:
            gbdt = Forest.gbdt3
        else:
            gbdt = Forest.gbdt4
        # NOTE: if set_errors could not load this model, gbdt is None and the
        # predict() call below raises AttributeError.
        if api:
            dq = data['C_FP_VALUE']
            # Turn the single row into a (1, n_features) matrix for predict().
            eat_data = data[cols].values[np.newaxis, :]
            error_predict = gbdt.predict(eat_data)[0]
            dq_fix = round(dq + error_predict, 2)
            dq_fix = dq_fix if dq_fix > 0 else 0
            dq_fix = self.opt.cap if dq_fix > self.opt.cap else dq_fix
        else:
            dq = data['C_FP_VALUE'].values
            eat_data = data[cols].values
            error_predict = gbdt.predict(eat_data)
            dq_fix = dq + error_predict
            dq_fix[dq_fix < 0] = 0  # negative values are set to 0
            dq_fix[dq_fix > self.opt.cap] = self.opt.cap  # values above installed capacity are capped
            dq_fix = np.around(dq_fix, decimals=2)
        return dq_fix

    def predict_error(self, pre_data):
        """Correct every row of ``pre_data`` with the model for its horizon.

        The row index label selects the model:

        - index < 4 -> hour 1
        - index < 8 -> hour 2
        - index < 12 -> hour 3
        - otherwise -> hour 4

        NOTE(review): this assumes an integer index counting forecast points
        from 0; confirm with callers.

        :param pre_data: ``pd.DataFrame`` containing the feature columns and
            ``C_FP_VALUE``.
        :return: list of corrected values, one per row, in row order.
        """
        dq_fixs = []
        for point, data in pre_data.iterrows():
            if point < 4:
                hour = 1
            elif point < 8:
                hour = 2
            elif point < 12:
                hour = 3
            else:
                hour = 4
            dq_fixs.append(self.predict_error_clock(hour, data, api=True))
        return dq_fixs


if __name__ == '__main__':
    from config import myargparse
    from logs import Log
    args = myargparse(discription="场站端配置", add_help=False)
    log = Log()