task_worker.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. #!/usr/bin/env python
  2. # -*- coding:utf-8 -*-
  3. # @FileName :task_worker.py
  4. # @Time :2025/4/29 11:05
  5. # @Author :David
  6. # @Company: shenyang JY
  7. import logging
  8. import pandas as pd
  9. from app.model.tf_model_train import ModelTrainer
  10. from app.model.tf_region_train import RegionTrainer
  11. from app.model.material import MaterialLoader
  12. class Task(object):
  13. def __init__(self, loader):
  14. self.loader = loader
  15. def station_task(self, config):
  16. """场站级训练任务"""
  17. station_id = -99
  18. try:
  19. print("111")
  20. station_id = config['station_id']
  21. # 动态生成场站数据路径
  22. print("222")
  23. # 加载数据
  24. data_objects = self.loader.get_material(station_id)
  25. local_weights = self.loader.add_weights(data_objects)
  26. print("333")
  27. # 数据合并
  28. train_data = pd.merge(data_objects.nwp_v_h, data_objects.power, on=config['col_time'])
  29. print("444")
  30. # 模型训练
  31. # model = ModelTrainer(station_id, train_data, capacity=data_objects.cap, gpu_id=config.get('gpu_assignment'))
  32. model = ModelTrainer(train_data, capacity=data_objects.cap, config=config)
  33. model.train()
  34. print("555")
  35. return {'status': 'success', 'station_id': station_id, 'weights': local_weights}
  36. except Exception as e:
  37. logging.error(f"Station {station_id} failed: {str(e)}")
  38. return {'status': 'failed', 'station_id': station_id}
  39. def region_task(self, config, data_nwps):
  40. """区域级训练任务"""
  41. area_id = -99
  42. try:
  43. print("111")
  44. # 动态生成场站数据路径
  45. print("222")
  46. # 加载数据
  47. data_objects = self.loader.get_material_region()
  48. config['area_id'] = data_objects.area_id
  49. area_id = data_objects.area_id
  50. print("333")
  51. # 数据合并
  52. print(data_nwps.nwp)
  53. print(data_nwps.nwp_v)
  54. train_data = pd.merge(data_nwps.nwp_v_h, data_objects.power, on=config['col_time'])
  55. print("444")
  56. # 模型训练
  57. model = ModelTrainer(train_data, capacity=data_objects.cap, config=config)
  58. model.train()
  59. print("555")
  60. return {'status': 'success', 'area_id': area_id}
  61. except Exception as e:
  62. logging.error(f"Area {area_id} failed: {str(e)}")
  63. return {'status': 'failed', 'area_id': area_id}