data_process.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # time: 2024/5/6 13:52
  4. # file: data_process.py
  5. # author: David
  6. # company: shenyang JY
  7. import os
  8. import numpy as np
  9. import pandas as pd
  10. from joblib.parallel import method
  11. from cache.data_cleaning import rm_duplicated
  12. np.random.seed(42)
  13. class DataProcess(object):
  14. def __init__(self, log, args):
  15. self.logger = log
  16. self.args = args
  17. self.opt = self.args.parse_args_and_yaml()
  18. # 主要是联立后的补值操作
  19. def get_train_data(self, unite, envir):
  20. # unite = pd.merge(unite, envir, on='C_TIME')
  21. # 第二步:计算间隔
  22. unite['C_TIME'] = pd.to_datetime(unite['C_TIME'])
  23. unite['time_diff'] = unite['C_TIME'].diff()
  24. dt_short = pd.Timedelta(minutes=15)
  25. dt_long = pd.Timedelta(minutes=15 * self.opt.Model['how_long_fill'])
  26. data_train = self.missing_time_splite(unite, dt_short, dt_long)
  27. miss_points = unite[(unite['time_diff'] > dt_short) & (unite['time_diff'] < dt_long)]
  28. miss_number = miss_points['time_diff'].dt.total_seconds().sum(axis=0)/(15*60) - len(miss_points)
  29. self.logger.info("再次测算,需要插值的总点数为:{}".format(miss_number))
  30. if miss_number > 0 and self.opt.Model["train_data_fill"]:
  31. data_train = self.data_fill(data_train)
  32. envir.set_index('C_TIME', inplace=True)
  33. envir = envir.interpolate()
  34. envir = envir.fillna('bfill')
  35. envir = envir.fillna('ffill')
  36. envir.reset_index(inplace=True, drop=False)
  37. return data_train, envir
  38. def get_test_data(self, unite, envir):
  39. unite['C_TIME'] = pd.to_datetime(unite['C_TIME'])
  40. unite['time_diff'] = unite['C_TIME'].diff()
  41. dt_short = pd.Timedelta(minutes=15)
  42. dt_long = pd.Timedelta(minutes=15 * self.opt.Model['how_long_fill'])
  43. data_test = self.missing_time_splite(unite, dt_short, dt_long)
  44. miss_points = unite[(unite['time_diff'] > dt_short) & (unite['time_diff'] < dt_long)]
  45. miss_number = miss_points['time_diff'].dt.total_seconds().sum(axis=0) / (15 * 60) - len(miss_points)
  46. self.logger.info("再次测算,需要插值的总点数为:{}".format(miss_number))
  47. if self.opt.Model["predict_data_fill"] and miss_number > 0:
  48. data_test = self.data_fill(data_test, test=True)
  49. return data_test, envir
  50. def get_predict_data(self, nwp, dq):
  51. if self.opt.Model["predict_data_fill"] and len(dq) > len(nwp):
  52. self.logger.info("接口nwp和dq合并清洗后,需要插值的总点数为:{}".format(len(dq)-len(nwp)))
  53. nwp.set_index('C_TIME', inplace=True)
  54. dq.set_index('C_TIME', inplace=True)
  55. nwp = nwp.resample('15T').interpolate(method='linear') # nwp先进行线性填充
  56. nwp = nwp.reindex(dq.index, method='bfill') # 再对超过采样边缘无法填充的点进行二次填充
  57. nwp = nwp.reindex(dq.index, method='ffill')
  58. nwp.reset_index(drop=False, inplace=True)
  59. dq.reset_index(drop=False, inplace=True)
  60. return nwp
  61. def missing_time_splite(self, df, dt_short, dt_long):
  62. n_long, n_short, n_points = 0, 0, 0
  63. start_index = 0
  64. dfs = []
  65. for i in range(1, len(df)):
  66. if df['time_diff'][i] >= dt_long:
  67. df_long = df.iloc[start_index:i, :-1]
  68. dfs.append(df_long)
  69. start_index = i
  70. n_long += 1
  71. if df['time_diff'][i] > dt_short:
  72. self.logger.info(f"{df['C_TIME'][i-1]} ~ {df['C_TIME'][i]}")
  73. points = round(df['time_diff'].dt.total_seconds()[i]/(60*15))-1
  74. self.logger.info("缺失点数:{}".format(points))
  75. if df['time_diff'][i] < dt_long:
  76. n_short += 1
  77. n_points += points
  78. print("需要补值的点数:", points)
  79. dfs.append(df.iloc[start_index:, :-1])
  80. self.logger.info(f"数据总数:{len(df)}, 时序缺失的间隔:{n_short}, 其中,较长的时间间隔:{n_long}")
  81. self.logger.info("需要补值的总点数:{}".format(n_points))
  82. return dfs
  83. def data_fill(self, dfs, test=False):
  84. dfs_fill, inserts = [], 0
  85. for i, df in enumerate(dfs):
  86. df = rm_duplicated(df)
  87. df1 = df.set_index('C_TIME', inplace=False)
  88. dff = df1.resample('15T').interpolate(method='linear') # 采用线性补值,其他补值方法需要进一步对比
  89. dff.reset_index(inplace=True)
  90. points = len(dff) - len(df1)
  91. dfs_fill.append(dff)
  92. self.logger.info(
  93. "{} ~ {} 有 {} 个点, 填补 {} 个点.".format(dff.iloc[0, 0], dff.iloc[-1, 0], len(dff), points))
  94. inserts += points
  95. name = "预测数据" if test is True else "训练集"
  96. self.logger.info("{}分成了{}段,实际一共补值{}点".format(name, len(dfs_fill), inserts))
  97. return dfs_fill