task_worker.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. #!/usr/bin/env python
  2. # -*- coding:utf-8 -*-
  3. # @FileName :task_worker.py
  4. # @Time :2025/4/29 11:05
  5. # @Author :David
  6. # @Company: shenyang JY
  7. import logging
  8. import pandas as pd
  9. from app.model.tf_model_train import ModelTrainer
  10. from app.model.tf_region_train import RegionTrainer
  11. from app.model.material import MaterialLoader
  12. def station_task(config):
  13. """场站级训练任务"""
  14. try:
  15. print("111")
  16. station_id = config['station_id']
  17. mate = MaterialLoader(base_path=config['input_file'])
  18. # 动态生成场站数据路径
  19. print("222")
  20. # 加载数据
  21. data_objects = mate.get_material(station_id)
  22. print("333")
  23. # 数据合并
  24. train_data = pd.merge(data_objects.nwp_v_h, data_objects.power, on=config['col_time'])
  25. print("444")
  26. # 模型训练
  27. # model = ModelTrainer(station_id, train_data, capacity=data_objects.cap, gpu_id=config.get('gpu_assignment'))
  28. model = ModelTrainer(train_data, capacity=data_objects.cap, config=config)
  29. model.train()
  30. print("555")
  31. return {'status': 'success', 'station_id': station_id}
  32. except Exception as e:
  33. logging.error(f"Station {station_id} failed: {str(e)}")
  34. return {'status': 'failed', 'station_id': station_id}
  35. def region_task(config):
  36. """区域级训练任务"""
  37. try:
  38. print("111")
  39. station_id = config['station_id']
  40. mate = MaterialLoader(base_path=config['input_file'])
  41. # 动态生成场站数据路径
  42. print("222")
  43. # 加载数据
  44. data_objects = mate.get_material_region()
  45. print("333")
  46. # 数据合并
  47. train_data = pd.merge(data_objects.nwp_v_h, data_objects.power, on=config['col_time'])
  48. print("444")
  49. # 模型训练
  50. model = ModelTrainer(train_data, capacity=data_objects.cap, config=config)
  51. model.train()
  52. print("555")
  53. return {'status': 'success', 'station_id': station_id}
  54. except Exception as e:
  55. logging.error(f"Station {station_id} failed: {str(e)}")
  56. return {'status': 'failed', 'station_id': station_id}