figure.py 4.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # time: 2023/3/20 15:19
  4. # file: figure.py
  5. # author: David
  6. # company: shenyang JY
  7. import sys
  8. import numpy as np
  9. import matplotlib.pyplot as plt
  10. class Figure(object):
  11. def __init__(self, opt, logger, process):
  12. self.opt = opt
  13. self.ds = process
  14. self.logger = logger
  15. def get_16_points(self, results):
  16. # results为模型预测的一维数组,遍历,取每16个点的最后一个点
  17. preds = []
  18. for res in results:
  19. preds.append(res[-1])
  20. return np.array(preds)
  21. def draw(self, label_data, predict_norm_data, numbers):
  22. # label_data = origin_data.data[origin_data.train_num + origin_data.start_num_in_test : ,
  23. # config.label_in_feature_index]
  24. # dq_data = dq_data.reshape((-1, self.opt.output_size))
  25. predict_norm_data = self.get_16_points(predict_norm_data)
  26. label_data = self.get_16_points(label_data)
  27. label_data = label_data.reshape((-1, self.opt.output_size))
  28. # label_data 要进行反归一化
  29. label_data = label_data * self.ds.std[self.opt.label_in_feature_index] + \
  30. self.ds.mean[self.opt.label_in_feature_index]
  31. predict_data = predict_norm_data * self.ds.std[self.opt.label_in_feature_index] + \
  32. self.ds.mean[self.opt.label_in_feature_index] # 通过保存的均值和方差还原数据
  33. # dq_data = dq_data * self.ds.std[0] + self.ds.mean[0]
  34. # predict_data = predict_norm_data
  35. assert label_data.shape[0] == predict_data.shape[0], "The element number in origin and predicted data is different"
  36. label_name = [self.ds.tables_column_name[i] for i in self.opt.label_in_feature_index]
  37. label_column_num = len(self.opt.label_columns)
  38. # label 和 predict 是错开config.predict_day天的数据的
  39. # 下面是两种norm后的loss的计算方式,结果是一样的,可以简单手推一下
  40. # label_norm_data = origin_data.norm_data[origin_data.train_num + origin_data.start_num_in_test:,
  41. # config.label_in_feature_index]
  42. # loss_norm = np.mean((label_norm_data[config.predict_day:] - predict_norm_data[:-config.predict_day]) ** 2, axis=0)
  43. # logger.info("The mean squared error of stock {} is ".format(label_name) + str(loss_norm))
  44. loss = np.sum((label_data - predict_data) ** 2)/len(label_data) # mse
  45. # loss = np.mean((label_data - predict_data) ** 2, axis=0)
  46. loss_sqrt = np.sqrt(loss) # rmse
  47. loss_norm = 1 - loss_sqrt / self.opt.cap
  48. # loss_norm = loss/(ds.std[opt.label_in_feature_index] ** 2)
  49. self.logger.info("The mean squared error of power {} is ".format(label_name) + str(loss_norm))
  50. # loss1 = np.sum((label_data - dq_data) ** 2) / len(label_data) # mse
  51. # loss_sqrt1 = np.sqrt(loss1) # rmse
  52. # loss_norm1 = 1 - loss_sqrt1 / self.opt.cap
  53. # self.logger.info("The mean squared error1 of power {} is ".format(label_name) + str(loss_norm1))
  54. if self.opt.is_continuous_predict:
  55. # label_X = range(int((self.ds.data_num - self.ds.train_num - 32)))
  56. label_X = list(range(numbers))
  57. else:
  58. label_X = range(int((self.ds.data_num - self.ds.train_num - self.ds.start_num_in_test)/2))
  59. print("label_x = ", label_X)
  60. predict_X = [x for x in label_X]
  61. if not sys.platform.startswith('linux'): # 无桌面的Linux下无法输出,如果是有桌面的Linux,如Ubuntu,可去掉这一行
  62. for i in range(label_column_num):
  63. plt.figure(i+1) # 预测数据绘制
  64. plt.plot(label_X, label_data[:, i], label='label', color='b')
  65. plt.plot(predict_X, predict_data[:, i], label='predict', color='g')
  66. # plt.plot(predict_X, dq_data[:, i], label='dq', color='y')
  67. # plt.title("Predict actual {} power with {}".format(label_name[i], self.opt.used_frame))
  68. self.logger.info("The predicted power {} for the last {} point(s) is: ".format(label_name[i], self.opt.predict_points) +
  69. str(np.squeeze(predict_data[-self.opt.predict_points:, i])))
  70. if self.opt.do_figure_save:
  71. plt.savefig(self.opt.figure_save_path+"{}predict_{}_with_{}.png".format(self.opt.continue_flag, label_name[i], opt.used_frame))
  72. plt.show()