#!/usr/bin/env python
# -*- coding: utf-8 -*-
# time: 2023/4/12 17:42
# file: data_features.py
# author: David
# company: shenyang JY
import pandas as pd
import numpy as np

np.random.seed(42)


class DataFeatures(object):
    def __init__(self, log, args):
        self.logger = log
        self.args = args
        self.opt = self.args.parse_args_and_yaml()
        self.columns = list()
    def train_valid_split(self, datax, datay, valid_rate, shuffle):
        # Split the sample windows into training and validation sets,
        # either on a shuffled index order or in the original order.
        shuffle_index = np.random.permutation(len(datax))
        indexs = shuffle_index.tolist() if shuffle else np.arange(0, len(datax)).tolist()
        valid_size = int(len(datax) * valid_rate)
        valid_index = indexs[-valid_size:]
        train_index = indexs[:-valid_size]
        tx, vx, ty, vy = [], [], [], []
        for i, data in enumerate(zip(datax, datay)):
            if i in train_index:
                tx.append(data[0])
                ty.append(data[1])
            elif i in valid_index:
                vx.append(data[0])
                vy.append(data[1])
        return tx, vx, ty, vy
    def get_train_data(self, dfs, envir):
        # Build training and validation inputs/labels from a list of normalized DataFrames.
        num = 1
        train_x, valid_x, train_y, valid_y = [], [], [], []
        for i, df in enumerate(dfs, start=1):
            if len(df) < self.opt.Model["time_step"]:
                self.logger.info("Feature processing - training data - length below time_step +{}".format(num))
                num += 1
                continue
            datax, datay = self.get_data_features(df, envir, is_train=True)
            if len(datax) < 10:
                self.logger.info("Feature processing - training data - too few samples for the minimum split +{}".format(num))
                num += 1
                continue
            tx, vx, ty, vy = self.train_valid_split(datax, datay, valid_rate=self.opt.Model["valid_data_rate"], shuffle=self.opt.Model['shuffle_train_data'])
            train_x.extend(tx)
            valid_x.extend(vx)
            train_y.extend(ty)
            valid_y.extend(vy)
        train_y = np.concatenate([[y.iloc[:, 1].values for y in train_y]], axis=0)
        valid_y = np.concatenate([[y.iloc[:, 1].values for y in valid_y]], axis=0)
        train_x = [np.array([x[0].values for x in train_x]), np.array([x[1].values for x in train_x])]
        valid_x = [np.array([x[0].values for x in valid_x]), np.array([x[1].values for x in valid_x])]
        return train_x, valid_x, train_y, valid_y
    def get_test_data(self, dfs, envir):
        # Build test inputs/labels; data_y keeps the full label DataFrames (including C_TIME) for later evaluation.
        num = 0
        test_x, test_y, data_y = [], [], []
        for i, df in enumerate(dfs, start=1):
            if len(df) < self.opt.Model["time_step"]:
                self.logger.info("Feature processing - test data - length below time_step +{}".format(num))
                num += 1
                continue
            datax, datay = self.get_data_features(df, envir, is_train=False)
            test_x.extend(datax)
            test_y.extend(datay)
            data_y.extend(datay)
        test_x = [np.array([x[0].values for x in test_x]), np.array([x[1].values for x in test_x])]
        test_y = np.concatenate([[y.iloc[:, 1].values for y in test_y]], axis=0)
        return test_x, test_y, data_y
    def get_realtime_data(self, dfs, envir):
        # Build model inputs for real-time prediction (no labels available).
        test_x = []
        for i, df in enumerate(dfs, start=1):
            if len(df) < self.opt.Model["time_step"]:
                self.logger.info("Feature processing - prediction data - length below time_step")
                continue
            datax = self.get_realtime_data_features(df, envir)
            test_x.extend(datax)
        test_x = [np.array([x[0].values for x in test_x]), np.array([x[1].values for x in test_x])]
        return test_x
    def get_data_features(self, norm_data, envir, is_train):   # optimized sliding-window implementation based on pandas
        time_step = self.opt.Model["time_step"]
        feature_data = norm_data.reset_index(drop=True)
        time_step_loc = time_step - 1
        train_num = int(len(feature_data))
        label_features = ['C_TIME', 'C_REAL_VALUE']   # same label columns for training and testing
        nwp_cs = self.opt.nwp_columns.copy()
        if 'C_TIME' in nwp_cs:
            nwp_cs.pop(nwp_cs.index('C_TIME'))
        nwp = [feature_data.loc[i:i + time_step_loc, nwp_cs].reset_index(drop=True) for i in range(train_num - time_step + 1)]   # database fields 'C_T': 'C_WS170'
        labels = [feature_data.loc[i:i + time_step_loc, label_features].reset_index(drop=True) for i in range(train_num - time_step + 1)]
        features_x, features_y = [], []
        env_fill = envir[-self.opt.Model["his_points"]:]
        self.logger.info("Before matching environment data: {} windows -> ".format(len(nwp)))
        for i, row in enumerate(zip(nwp, labels)):
            # Take the his_points most recent environment rows from the day preceding each window.
            time_end = row[1]['C_TIME'][0]
            time_start = time_end - pd.DateOffset(1)
            row1 = envir[(envir.C_TIME < time_end) & (envir.C_TIME > time_start)][-self.opt.Model["his_points"]:]
            if len(row1) < self.opt.Model["his_points"]:
                if self.opt.Model['fusion']:
                    self.logger.info("Training environment data has fewer than {} points: {}, filling with the latest data".format(self.opt.Model["his_points"], len(row1)))
                    row1 = env_fill
                else:
                    self.logger.info("Training environment data has fewer than {} points: {}, discarding this window".format(self.opt.Model["his_points"], len(row1)))
                    continue
            row1 = row1.reset_index(drop=True).drop(['C_TIME'], axis=1)
            features_x.append([row1, row[0]])
            features_y.append(row[1])
        self.logger.info("After matching environment data: {} windows".format(len(features_x)))
        return features_x, features_y
    def get_realtime_data_features(self, norm_data, envir):   # optimized sliding-window implementation based on pandas
        time_step = self.opt.Model["time_step"]
        feature_data = norm_data.reset_index(drop=True)
        time_step_loc = time_step - 1
        nwp_cs = self.opt.nwp_columns.copy()
        if 'C_TIME' in nwp_cs:
            nwp_cs.pop(nwp_cs.index('C_TIME'))
        nwp = [feature_data.loc[i:i + time_step_loc, nwp_cs].reset_index(drop=True) for i in range(1)]   # database fields 'C_T': 'C_WS170'
        features_x = []
        self.logger.info("Before matching environment data: {} windows -> ".format(len(nwp)))
        for i, row in enumerate(nwp):
            row1 = envir[-self.opt.Model["his_points"]:]
            if len(row1) < self.opt.Model["his_points"]:
                self.logger.info("Environment data has fewer than {} points: {}".format(self.opt.Model["his_points"], len(row1)))
                continue
            row1 = row1.reset_index(drop=True).drop(['C_TIME'], axis=1)
            features_x.append([row1, row])
        self.logger.info("After matching environment data: {} windows".format(len(features_x)))
        return features_x
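

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it only shows how the class
    # above is wired together. The _Opt and _ArgsStub classes below are hypothetical
    # stand-ins for the project's real argument/YAML parser, which is assumed to expose
    # parse_args_and_yaml() returning an object with Model and nwp_columns attributes,
    # as accessed in __init__ and the feature methods above.
    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("data_features")

    class _Opt(object):
        # Hypothetical option container mirroring the attributes read by DataFeatures.
        Model = {"time_step": 16, "valid_data_rate": 0.2, "shuffle_train_data": True,
                 "his_points": 16, "fusion": True}
        nwp_columns = ['C_TIME', 'C_WS170']

    class _ArgsStub(object):
        # Hypothetical stand-in for the real argument parser used elsewhere in the project.
        def parse_args_and_yaml(self):
            return _Opt()

    features = DataFeatures(logger, _ArgsStub())

    # dfs would be a list of normalized DataFrames containing C_TIME, C_REAL_VALUE and
    # the NWP columns; envir a DataFrame of historical environment data with C_TIME.
    # train_x, valid_x, train_y, valid_y = features.get_train_data(dfs, envir)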