#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :task_worker.py
# @Time      :2025/4/29 11:05
# @Author    :David
# @Company: shenyang JY
import logging

import pandas as pd

from app.model.tf_model_train import ModelTrainer
from app.model.tf_region_train import RegionTrainer
from app.model.material import MaterialLoader

logger = logging.getLogger(__name__)


class Task:
    """Run station-level and region-level model training jobs.

    Each task method loads its material through ``self.loader``, merges the
    NWP features with the power data on the configured time column, trains a
    ``ModelTrainer`` and reports the outcome as a plain ``dict``. Failures are
    logged and reported in the returned dict instead of being raised.
    """

    def __init__(self, loader):
        """Store the material loader used by the task methods.

        Args:
            loader: Object providing ``get_material(station_id)``,
                ``add_weights(data_objects)`` and ``get_material_region()``.
                (presumably a ``MaterialLoader`` -- confirm against callers.)
        """
        self.loader = loader

    def station_task(self, config: dict) -> dict:
        """Train the model for a single station.

        Args:
            config: Task configuration. Reads ``config['station_id']`` and
                ``config['col_time']`` (the column both frames are merged on);
                the whole dict is also forwarded to ``ModelTrainer``.

        Returns:
            ``{'status': 'success', 'station_id': ..., 'weights': ...}`` on
            success, where ``weights`` is the value returned by
            ``loader.add_weights``; otherwise
            ``{'status': 'failed', 'station_id': ...}``. ``station_id`` is
            ``-99`` when the failure happened before the ID could be read.
        """
        station_id = -99  # sentinel reported if the ID lookup itself fails
        try:
            station_id = config['station_id']

            # Load the station's material and derive its local weights.
            logger.debug("Station %s: loading material", station_id)
            data_objects = self.loader.get_material(station_id)
            local_weights = self.loader.add_weights(data_objects)

            # Join NWP features with measured power on the time column
            # (pd.merge defaults to an inner join).
            logger.debug("Station %s: merging NWP and power data", station_id)
            train_data = pd.merge(
                data_objects.nwp_v_h,
                data_objects.power,
                on=config['col_time'],
            )

            logger.debug("Station %s: training model", station_id)
            model = ModelTrainer(
                train_data, capacity=data_objects.cap, config=config
            )
            model.train()
            logger.debug("Station %s: training finished", station_id)

            return {
                'status': 'success',
                'station_id': station_id,
                'weights': local_weights,
            }
        except Exception as e:
            # Boundary handler: one failing station must not propagate to the
            # caller. logger.exception keeps the traceback for diagnosis.
            logger.exception("Station %s failed: %s", station_id, e)
            return {'status': 'failed', 'station_id': station_id}

    def region_task(self, config: dict, data_nwps) -> dict:
        """Train the model for the whole region (area).

        Args:
            config: Task configuration. Reads ``config['col_time']`` and is
                mutated in place: ``config['area_id']`` is set to the area ID
                of the loaded region material before training.
            data_nwps: Object whose ``nwp_v_h`` attribute holds the NWP
                features merged with the region's power data.

        Returns:
            ``{'status': 'success', 'area_id': ...}`` on success, otherwise
            ``{'status': 'failed', 'area_id': ...}``. ``area_id`` is ``-99``
            when the failure happened before the region material was loaded.
        """
        area_id = -99  # sentinel reported if loading the region fails
        try:
            # Load the region-level material and publish its ID via config.
            logger.debug("Region: loading material")
            data_objects = self.loader.get_material_region()
            area_id = data_objects.area_id
            config['area_id'] = area_id

            # Join the supplied NWP features with the region's power data on
            # the time column (pd.merge defaults to an inner join).
            logger.debug("Area %s: merging NWP and power data", area_id)
            train_data = pd.merge(
                data_nwps.nwp_v_h,
                data_objects.power,
                on=config['col_time'],
            )

            # NOTE(review): RegionTrainer is imported by this module but the
            # region task trains with ModelTrainer -- confirm this is intended.
            logger.debug("Area %s: training model", area_id)
            model = ModelTrainer(
                train_data, capacity=data_objects.cap, config=config
            )
            model.train()
            logger.debug("Area %s: training finished", area_id)

            return {'status': 'success', 'area_id': area_id}
        except Exception as e:
            # Boundary handler: report the failure to the caller instead of
            # raising. logger.exception keeps the traceback for diagnosis.
            logger.exception("Area %s failed: %s", area_id, e)
            return {'status': 'failed', 'area_id': area_id}