#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :tf_model_train.py
# @Time      :2025/4/29 14:05
# @Author    :David
# @Company: shenyang JY
"""Station/area-level TF-LSTM model training wrapper.

Wires data pre-processing (DataHandler), the Keras/TF LSTM handler
(TSHandler) and MongoDB persistence (MongoUtils) into one training run.
"""
import argparse
import json
import os
import time
import traceback
from typing import Any, Dict

import pandas as pd

from app.common.tf_lstm import TSHandler
from app.common.dbmg import MongoUtils
from app.common.data_handler import DataHandler, write_number_to_file
from app.common.config import logger, parser


class ModelTrainer:
    """模型训练器封装类 — model-trainer wrapper for a single training run."""

    def __init__(self,
                 train_data: pd.DataFrame,
                 capacity: float,
                 config: Dict[str, Any] = None,
                 ):
        """
        Args:
            train_data: raw training frame handed to DataHandler.
            capacity: station/area rated capacity, forwarded to the data
                handler as ``opt.cap`` for target scaling.
            config: run configuration. Keys read here: 'gpu_assignment',
                'input_file'; ``train()`` additionally reads 'area_id'/
                'station_id', 'Model', 'model_table', 'scaler_table'.

        Raises:
            ValueError: if config is None (the original default would have
                crashed on ``config.get`` with an AttributeError anyway).
        """
        if config is None:
            raise ValueError("config is required")
        self.config = config
        self.logger = logger
        self.train_data = train_data
        self.capacity = capacity
        self.gpu_id = config.get('gpu_assignment')
        self._setup_resources()
        # Component wiring: options namespace shared by data/model handlers.
        self.input_file = config.get("input_file")
        self.opt = argparse.Namespace(**config)
        self.dh = DataHandler(logger, self.opt)
        self.ts = TSHandler(logger, self.opt)
        self.mgUtils = MongoUtils(logger)

    def _setup_resources(self):
        """GPU 资源分配 — pin this process to the assigned GPU, if any."""
        if self.gpu_id is not None:
            # NOTE(review): the __main__ example passes a *list* for
            # 'gpu_assignment'; str() of a list is not a valid
            # CUDA_VISIBLE_DEVICES value — confirm callers pass one id.
            os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)
            self.logger.info(f"GPU {self.gpu_id} allocated")

    def train(self, pre_area: bool = False) -> bool:
        """执行训练流程 — run preprocess → train → persist, end to end.

        Args:
            pre_area: True for area-level (区域级, type 'a') prediction,
                False for station-level (场站级, type 's').

        Returns:
            True on success; False on any failure (logged via _handle_error).
        """
        start_time = time.time()
        # Prediction id/type: station vs area.
        pre_id = self.config['area_id'] if pre_area else self.config['station_id']
        pre_type = 'a' if pre_area else 's'
        output_file = self.input_file.replace('IN', 'OUT')
        status_file = 'STATUS.TXT'
        try:
            # ------------ fetch & pre-process training data ------------
            self.dh.opt.cap = self.capacity
            train_x, valid_x, train_y, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap = \
                self.dh.train_data_handler(self.train_data)
            self.ts.opt.Model['input_size'] = train_x.shape[2]

            # ------------ model initialization ------------
            # add_train (加强训练): load the pretrained model first, then
            # re-run the data handler restricted to the pretrained feature
            # set; otherwise build a fresh Keras model from current options.
            if self.ts.opt.Model['add_train']:
                model = self.ts.train_init()
                if model:
                    feas = json.loads(self.ts.model_params).get('features', self.dh.opt.features)
                    if set(feas).issubset(set(self.dh.opt.features)):
                        self.dh.opt.features = list(feas)
                        # FIX: unpack in the same order as the first
                        # train_data_handler call above; the original
                        # swapped valid_x/train_y here.
                        train_x, valid_x, train_y, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap = \
                            self.dh.train_data_handler(self.train_data)
                    else:
                        # Training data lacks the pretrained features —
                        # fall back to a fresh model.
                        model = self.ts.get_keras_model(self.ts.opt)
                        self.logger.info("训练数据特征,不满足,加强训练模型特征")
                else:
                    model = self.ts.get_keras_model(self.ts.opt)
            else:
                model = self.ts.get_keras_model(self.ts.opt)

            # ------------ train ------------
            trained_model = self.ts.training(model, [train_x, valid_x, train_y, valid_y])

            # Algorithm status file: 1 = started successfully.
            write_number_to_file(os.path.join(output_file, status_file), 1, 1, 'rewrite')

            # ------------ assemble & persist model payload ------------
            self.opt.Model['features'] = ','.join(self.dh.opt.features)
            self.config.update({
                'params': json.dumps(self.config['Model']),
                'descr': f'南网竞赛-{pre_id}',
                'gen_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                'model_table': self.config['model_table'] + f'_{pre_type}_' + str(pre_id),
                'scaler_table': self.config['scaler_table'] + f'_{pre_type}_' + str(pre_id)
            })
            self.mgUtils.insert_trained_model_into_mongo(trained_model, self.config)
            self.mgUtils.insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, self.config)

            # Algorithm status file: 2 = finished normally.
            write_number_to_file(os.path.join(output_file, status_file), 2, 2)
            self.logger.info("training finished in %.1fs", time.time() - start_time)
            return True
        except Exception as e:
            self._handle_error(e)
            return False

    def _initialize_model(self):
        """模型初始化策略 — unused draft of the init logic inlined in train()."""
        if self.ts.opt.Model['add_train']:
            pretrained = self.ts.train_init()
            if self._check_feature_compatibility(pretrained):
                return pretrained
        # FIX: original called get_keras_model() without the opt argument
        # used everywhere else in this file.
        return self.ts.get_keras_model(self.ts.opt)

    def _check_feature_compatibility(self, model) -> bool:
        """检查特征兼容性 — feature-compatibility check (not implemented).

        TODO: port the feature validation from train(); returning None here
        makes _initialize_model always fall back to a fresh model.
        """
        pass

    def _handle_error(self, error: Exception):
        """统一错误处理 — log the failure with its full traceback."""
        error_msg = traceback.format_exc()
        self.logger.error(f"Training failed: {str(error)}\n{error_msg}")


# 使用示例 (usage example)
if __name__ == "__main__":
    config = {
        'base_path': '/data/power_forecast',
        'capacities': {
            '1001': 2.5,
            '1002': 3.0,
            # ... 其他场站配置
        },
        'gpu_assignment': [0, 1, 2, 3]  # available GPU list
    }
    # NOTE(review): TrainingOrchestrator is neither defined nor imported in
    # this file — this example cannot run as-is; confirm where it lives.
    orchestrator = TrainingOrchestrator(
        station_ids=['1001', '1002', '1003'],  # 实际场景下生成数百个ID
        config=config,
        max_workers=4  # set according to GPU count
    )
    orchestrator.execute()