# model_forest_cdq.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# time: 2024/10/23 13:04
# file: forest.py
# author: David
# company: shenyang JY
import threading
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split, GridSearchCV
  14. class Forest(object):
  15. lock = threading.Lock()
  16. gbdt1 = None
  17. gbdt2 = None
  18. gbdt3 = None
  19. gbdt4 = None
  20. def __init__(self, log, args):
  21. self.args = args
  22. self.logger = log
  23. self.opt = self.args.parse_args_and_yaml()
  24. self.hs = []
  25. for fea in ['C_REAL_VALUE'] + [self.opt.usable_power["env"]]:
  26. self.hs += [fea + str(i + 1) for i in range(self.opt.history_env_hours * 4)]
  27. self.set_errors()
  28. def training_model(self, opt):
  29. cols = opt.data['columns'] + self.hs + ['C_FP_VALUE']
  30. data_path = Path('./cache/data/')
  31. files = list(data_path.rglob("lgb_data*.csv"))
  32. for file in files:
  33. data = pd.read_csv(file)
  34. df_train_x, df_train_y = data[cols].values, data['error'].values
  35. # train_x, val_x, train_y, val_y = train_test_split(df_train_x, df_train_y, test_size=0.1, shuffle=False)
  36. # 敲定好一组参数
  37. params_grid = {
  38. 'n_estimators': [50, 100, 200, 500],
  39. 'max_features': [None, 'sqrt', 'log2'],
  40. 'max_depth': [4, 6, 8, 10],
  41. # 'criterion': ['squared_error', 'absolute_error']
  42. }
  43. rf = RandomForestRegressor(random_state=42)
  44. grid_search = GridSearchCV(estimator=rf, param_grid=params_grid, cv=5, n_jobs=None, verbose=2, scoring='neg_mean_squared_error')
  45. grid_search.fit(df_train_x, df_train_y)
  46. # 输出最佳参数和最佳得分
  47. self.logger.info(f"Best parameters found: {grid_search.best_params_}")
  48. print(f"Best cross-validation score: {-grid_search.best_score_}")
  49. # 使用最佳模型进行预测
  50. best_model = grid_search.best_estimator_
  51. # pred_y = best_model.predict(val_x)
  52. # 计算测试集上的均方误差
  53. # mse = mean_squared_error(val_y, pred_y)
  54. # print(f"Mean Squared Error on test set: {mse}")
  55. # 保存模型
  56. self.logger.info('保存模型...')
  57. joblib.dump(best_model, './var/lgb_model_{}.pkl'.format(str(file)[-5]))
  58. @classmethod
  59. def set_errors(cls):
  60. try:
  61. with cls.lock:
  62. cls.gbdt1 = joblib.load('./var/lgb_model_1.pkl')
  63. cls.gbdt2 = joblib.load('./var/lgb_model_2.pkl')
  64. cls.gbdt3 = joblib.load('./var/lgb_model_3.pkl')
  65. cls.gbdt4 = joblib.load('./var/lgb_model_4.pkl')
  66. except Exception as e:
  67. print("加载模型权重失败:{}".format(e.args))
  68. def predict_error_clock(self, hour, data, api=False):
  69. cols = self.opt.data['columns'] + self.hs + ['C_FP_VALUE']
  70. if hour == 1:
  71. gbdt = Forest.gbdt1
  72. elif hour == 2:
  73. gbdt = Forest.gbdt2
  74. self.logger.info("预测模型,地址:{}".format(id(gbdt)))
  75. elif hour == 3:
  76. gbdt = Forest.gbdt3
  77. else:
  78. gbdt = Forest.gbdt4
  79. if api:
  80. dq = data['C_FP_VALUE']
  81. eat_data = data[cols].values[np.newaxis, :]
  82. error_predict = gbdt.predict(eat_data)[0]
  83. dq_fix = round(dq + error_predict, 2)
  84. dq_fix = dq_fix if dq_fix > 0 else 0
  85. dq_fix = self.opt.cap if dq_fix > self.opt.cap else dq_fix
  86. else:
  87. dq = data['C_FP_VALUE'].values
  88. eat_data = data[cols].values
  89. error_predict = gbdt.predict(eat_data)
  90. dq_fix = dq + error_predict
  91. dq_fix[dq_fix < 0] = 0 # 如果出现负数,置为0
  92. dq_fix[dq_fix > self.opt.cap] = self.opt.cap # 出现大于实际装机量的数,置为实际装机量
  93. dq_fix = np.around(dq_fix, decimals=2)
  94. return dq_fix
  95. def predict_error(self, pre_data):
  96. dq_fixs, ctimes = [], []
  97. for point, data in pre_data.iterrows():
  98. if point < 4:
  99. hour = 1
  100. elif point < 8:
  101. hour = 2
  102. elif point < 12:
  103. hour = 3
  104. else:
  105. hour = 4
  106. dq_fix = self.predict_error_clock(hour, data, api=True)
  107. dq_fixs.append(dq_fix)
  108. return dq_fixs
  109. if __name__ == '__main__':
  110. from config import myargparse
  111. from logs import Log
  112. args = myargparse(discription="场站端配置", add_help=False)
  113. log = Log()