task_worker.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. #!/usr/bin/env python
  2. # -*- coding:utf-8 -*-
  3. # @FileName :task_worker.py
  4. # @Time :2025/4/29 11:05
  5. # @Author :David
  6. # @Company: shenyang JY
  7. import logging
  8. import pandas as pd
  9. from app.model.tf_model_train import ModelTrainer
  10. from app.predict.tf_model_pre import ModelPre
  11. from app.model.tf_region_train import RegionTrainer
  12. from app.model.material import MaterialLoader
  13. class TaskTrain(object):
  14. def __init__(self, loader):
  15. self.loader = loader
  16. def station_task(self, config):
  17. """场站级训练任务"""
  18. station_id = -99
  19. try:
  20. print("111")
  21. station_id = config['station_id']
  22. # 动态生成场站数据路径
  23. print("222")
  24. # 加载数据
  25. data_objects = self.loader.get_material(station_id)
  26. local_weights = self.loader.add_weights(data_objects)
  27. print("333")
  28. # 数据合并
  29. train_data = pd.merge(data_objects.nwp_v_h, data_objects.power, on=config['col_time'])
  30. print("444")
  31. # 模型训练
  32. # model = ModelTrainer(station_id, train_data, capacity=data_objects.cap, gpu_id=config.get('gpu_assignment'))
  33. model = ModelTrainer(train_data, capacity=data_objects.cap, config=config)
  34. model.train()
  35. print("555")
  36. return {'status': 'success', 'station_id': station_id, 'weights': local_weights}
  37. except Exception as e:
  38. logging.error(f"Station {station_id} failed: {str(e)}")
  39. return {'status': 'failed', 'station_id': station_id}
  40. def region_task(self, config, data_nwps):
  41. """区域级训练任务"""
  42. area_id = -99
  43. try:
  44. print("111")
  45. # 动态生成场站数据路径
  46. print("222")
  47. # 加载数据
  48. data_objects = self.loader.get_material_region()
  49. config['area_id'] = data_objects.area_id
  50. area_id = data_objects.area_id
  51. print("333")
  52. # 数据合并
  53. print(data_nwps.nwp)
  54. print(data_nwps.nwp_v)
  55. print("累加的区域装机量{},实际区域装机量{}".format(data_nwps.total_cap, data_objects.area_cap))
  56. train_data = pd.merge(data_nwps.nwp_v_h, data_objects.power, on=config['col_time'])
  57. print("444")
  58. # 模型训练
  59. model = ModelTrainer(train_data, capacity=data_objects.area_cap, config=config)
  60. model.train(pre_area=True)
  61. print("555")
  62. return {'status': 'success', 'area_id': area_id}
  63. except Exception as e:
  64. logging.error(f"Area {area_id} failed: {str(e)}")
  65. return {'status': 'failed', 'area_id': area_id}
  66. class TaskPre(object):
  67. def __init__(self, loader):
  68. self.loader = loader
  69. def station_task(self, config):
  70. """场站级训练任务"""
  71. station_id = -99
  72. try:
  73. print("111")
  74. station_id = config['station_id']
  75. # 动态生成场站数据路径
  76. print("222")
  77. # 加载数据
  78. data_objects = self.loader.get_material(station_id)
  79. local_weights = self.loader.add_weights(data_objects)
  80. print("333")
  81. # 数据合并
  82. pre_data = data_objects.nwp_v
  83. print("444")
  84. # 模型训练
  85. # model = ModelTrainer(station_id, train_data, capacity=data_objects.cap, gpu_id=config.get('gpu_assignment'))
  86. model = ModelPre(pre_data, capacity=data_objects.cap, config=config)
  87. model.predict()
  88. print("555")
  89. return {'status': 'success', 'station_id': station_id, 'weights': local_weights}
  90. except Exception as e:
  91. logging.error(f"Station {station_id} failed: {str(e)}")
  92. return {'status': 'failed', 'station_id': station_id}
  93. def region_task(self, config, data_nwps):
  94. """区域级训练任务"""
  95. area_id = -99
  96. try:
  97. print("111")
  98. # 动态生成场站数据路径
  99. print("222")
  100. # 加载数据
  101. data_objects = self.loader.get_material_region()
  102. config['area_id'] = data_objects.area_id
  103. area_id = data_objects.area_id
  104. print("333")
  105. # 数据合并
  106. print(data_nwps.nwp)
  107. print(data_nwps.nwp_v)
  108. print("累加的区域装机量{},实际区域装机量{}".format(data_nwps.total_cap, data_objects.area_cap))
  109. pre_data = data_nwps.nwp_v
  110. print("444")
  111. # 模型训练
  112. model = ModelPre(pre_data, capacity=data_objects.area_cap, config=config)
  113. model.predict(pre_area=True)
  114. print("555")
  115. return {'status': 'success', 'area_id': area_id}
  116. except Exception as e:
  117. logging.error(f"Area {area_id} failed: {str(e)}")
  118. return {'status': 'failed', 'area_id': area_id}