formula.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # time: 2024/9/23 15:27
  4. # file: formula.py
  5. # author: David
  6. # company: shenyang JY
  7. import numpy as np
  8. class Formulas(object):
  9. def __init__(self, opt):
  10. self.opt = opt
  11. def calculate_acc(self, label_data, predict_data):
  12. loss = np.sum((label_data - predict_data) ** 2) / len(label_data) # mse
  13. loss_sqrt = np.sqrt(loss) # rmse
  14. loss_acc = (1 - loss_sqrt / self.opt.cap) * 100
  15. return loss_acc
  16. def calculate_acc_south(self, label_data, predict_data):
  17. cap = 0.1 * self.opt.cap
  18. mask = (label_data < cap) & (predict_data < cap)
  19. label_data = label_data[~mask]
  20. predict_data = predict_data[~mask]
  21. diff = label_data - predict_data
  22. base = np.where(label_data < self.opt.cap * 0.2, self.opt.cap * 0.2, label_data)
  23. acc = np.sum((diff / base) ** 2) / len(diff)
  24. acc = (1 - np.sqrt(acc)) * 100
  25. return acc
  26. def calculate_acc_northwest(self, label_data, predict_data):
  27. cap = 0.03 * self.opt.cap
  28. mask = (label_data < cap) & (predict_data < cap)
  29. label_data = label_data[~mask]
  30. predict_data = predict_data[~mask]
  31. diff = np.abs(label_data - predict_data)
  32. base1 = label_data + predict_data + 1e-9
  33. base2 = np.sum(diff) + 1e-9
  34. acc = (1 - 2 * np.sum(np.abs(label_data / base1 - 0.5) * diff / base2)) * 100
  35. return acc
  36. def calculate_acc_northeast(self, label_data, predict_data):
  37. cap = 0.1 * self.opt.cap
  38. mask = (label_data < cap) & (predict_data < cap)
  39. label_data = label_data[~mask]
  40. predict_data = predict_data[~mask]
  41. diff = np.abs(predict_data - label_data)
  42. deviation = diff / np.abs(predict_data + 1e-9)
  43. acc = np.where(deviation >= 1, 1, deviation)
  44. acc = 1 - np.mean(acc)
  45. return acc
  46. if __name__ == '__main__':
  47. from config import myargparse
  48. args = myargparse(discription="场站端配置", add_help=False)
  49. opt = args.parse_args_and_yaml()
  50. formula = Formulas(opt)