123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149 |
- #!/usr/bin/env python
- # -*- coding:utf-8 -*-
- # @FileName :tf_model_train.py
- # @Time :2025/4/29 14:05
- # @Author :David
- # @Company: shenyang JY
- import logging
- import os, json
- import time
- import traceback
- from pathlib import Path
- from copy import deepcopy
- import pandas as pd
- from typing import Dict, Any
- from app.common.tf_lstm import TSHandler
- from app.common.dbmg import MongoUtils
- from app.common.data_handler import DataHandler, write_number_to_file
- from app.model.config import myargparse
# Module-level logger. It must exist before MongoUtils(logger) below is
# constructed; without it, importing this module raises NameError.
logger = logging.getLogger(__name__)
# Shared algorithm configuration for batch model retraining.
# NOTE(review): the keyword is spelled `discription` as passed to the
# project's myargparse; confirm that is its actual parameter name.
args = myargparse(discription="算法批量修模", add_help=False)
mgUtils = MongoUtils(logger)
class ModelTrainer:
    """Train one station's forecasting model and persist it.

    The flow is:

    1. Preprocess ``train_data`` with ``DataHandler``.
    2. Build a fresh Keras model through ``TSHandler``, or load a pretrained
       one in incremental (``add_train``) mode.
    3. Train it and write status codes to ``STATUS.TXT`` under the output
       path.
    4. Store the trained model and scalers via ``MongoUtils``.
    """

    def __init__(self,
                 input_file,
                 args,
                 train_data: pd.DataFrame,
                 capacity: float,
                 gpu_id: int = None,
                 config: Dict[str, Any] = None
                 ):
        """Store inputs and create the data, model and storage helpers.

        Args:
            input_file: Input path for this station. ``train`` takes the
                station (farm) id from its second-to-last '/'-separated
                component and derives the output path by replacing 'IN'
                with 'OUT'.
            args: Algorithm configuration handed to ``DataHandler`` and
                ``TSHandler``. Presumably a dict: ``train`` deep-copies it
                and indexes/updates the copy. Confirm against myargparse.
            train_data: Raw training data for this station.
            capacity: Station capacity, assigned to ``DataHandler.opt.cap``.
            gpu_id: If given, exported as ``CUDA_VISIBLE_DEVICES``.
            config: Optional extra settings; stored but not read by this class.
        """
        self.train_data = train_data
        self.capacity = capacity
        self.gpu_id = gpu_id
        self.config = config or {}
        # The logger must exist before _setup_resources(), which logs the GPU
        # assignment (otherwise AttributeError whenever gpu_id is given).
        self.logger = logging.getLogger(self.__class__.__name__)
        self._setup_resources()
        # Build the helper components.
        self.input_file = input_file
        self.args = args  # taken from the original configuration
        self.dh = DataHandler(self.logger, self.args)
        self.ts = TSHandler(self.logger, self.args)
        self.mgUtils = MongoUtils(self.logger)

    def _setup_resources(self):
        """Restrict CUDA to ``self.gpu_id`` when one was assigned.

        NOTE(review): os.environ is process-wide, so with several trainers in
        one process the last assignment wins. The variable is typically only
        honoured if set before the framework initialises its GPU context.
        Confirm this matches how trainers are launched.
        """
        if self.gpu_id is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)
            self.logger.info(f"GPU {self.gpu_id} allocated")

    def train(self):
        """Run the training flow for this station.

        The steps are:

        1. Preprocess the data.
        2. Build or load the model.
        3. Train it, then write status (1, 1) ("started successfully") to
           ``<output>/STATUS.TXT``.
        4. Save the model and scalers to MongoDB.
        5. Write status (2, 2) ("normal end").

        Returns:
            True on success; False if any step raised (the error is logged
            by ``_handle_error``).
        """
        start_time = time.time()
        farm_id = self.input_file.split('/')[-2]
        output_file = self.input_file.replace('IN', 'OUT')
        status_file = 'STATUS.TXT'
        try:
            # ------------ Load and preprocess the training data ------------
            self.dh.opt.cap = self.capacity
            (train_x, valid_x, train_y, valid_y,
             scaled_train_bytes, scaled_target_bytes,
             scaled_cap) = self.dh.train_data_handler(self.train_data)
            self.ts.opt.Model['input_size'] = train_x.shape[2]

            # ------------ Build the model ------------
            # 1. Incremental mode: load the pretrained model. If its feature
            #    list is usable, re-preprocess with those features so the
            #    inputs match the pretrained network.
            # 2. Normal mode: build a fresh model sized from the training data.
            if self.ts.opt.Model['add_train']:
                model = self.ts.train_init()
                if self._check_feature_compatibility(model):
                    self.dh.opt.features = self._pretrained_features()
                    # Same unpacking order as the first call and as the list
                    # passed to training(), so arrays keep their positions.
                    (train_x, valid_x, train_y, valid_y,
                     scaled_train_bytes, scaled_target_bytes,
                     scaled_cap) = self.dh.train_data_handler(self.train_data)
                elif model:
                    model = self.ts.get_keras_model(self.ts.opt)
                    self.logger.info("训练数据特征,不满足,加强训练模型特征")
                else:
                    # No pretrained model could be loaded: start from scratch.
                    model = self.ts.get_keras_model(self.ts.opt)
            else:
                model = self.ts.get_keras_model(self.ts.opt)

            # Run training.
            trained_model = self.ts.training(model, [train_x, valid_x, train_y, valid_y])
            # Update algorithm status: 1 = started successfully.
            write_number_to_file(os.path.join(output_file, status_file), 1, 1, 'rewrite')

            # ------------ Assemble model metadata ------------
            features = ','.join(self.dh.opt.features)
            self.ts.opt.Model['features'] = features
            # Work on a copy so per-station values (e.g. model_table + farm_id)
            # do not accumulate in an args dict shared by several trainers.
            local_params = deepcopy(self.args)
            # NOTE(review): assumes args may carry the same 'Model' section
            # that ts.opt.Model is built from; record the features there too.
            if isinstance(local_params.get('Model'), dict):
                local_params['Model']['features'] = features
            local_params.update({
                'params': json.dumps(local_params),
                'descr': f'南网竞赛-{farm_id}',
                'gen_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                'model_table': local_params['model_table'] + farm_id,
                'scaler_table': local_params['scaler_table'] + farm_id
            })
            self.mgUtils.insert_trained_model_into_mongo(trained_model, local_params)
            self.mgUtils.insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, local_params)
            # Update algorithm status: normal end.
            write_number_to_file(os.path.join(output_file, status_file), 2, 2)
            self.logger.info("Station %s trained in %.2fs", farm_id, time.time() - start_time)
            return True
        except Exception as e:
            self._handle_error(e)
            return False

    def _pretrained_features(self) -> list:
        """Return the feature names recorded with the pretrained model.

        Reads ``features`` from the JSON string ``self.ts.model_params`` and
        falls back to the current ``self.dh.opt.features``. A comma-joined
        string (the format ``train`` writes into ``Model['features']``) is
        split; otherwise ``set()`` would compare it character by character.
        """
        feas = json.loads(self.ts.model_params).get('features', self.dh.opt.features)
        if isinstance(feas, str):
            feas = [name for name in feas.split(',') if name]
        return list(feas)

    def _initialize_model(self):
        """Return the model to train.

        In add_train mode, returns the pretrained model when its features are
        compatible; otherwise returns a freshly built Keras model. Unlike
        ``train``, this does not re-preprocess data for the pretrained
        feature set.
        """
        if self.ts.opt.Model['add_train']:
            pretrained = self.ts.train_init()
            if self._check_feature_compatibility(pretrained):
                return pretrained
        # Same construction call as train() uses.
        return self.ts.get_keras_model(self.ts.opt)

    def _check_feature_compatibility(self, model) -> bool:
        """Return True if ``model`` was loaded and its stored feature list is a
        subset of the features currently configured on the data handler."""
        if not model:
            return False
        return set(self._pretrained_features()).issubset(self.dh.opt.features)

    def _handle_error(self, error: Exception):
        """Log ``error`` with the current traceback.

        Intended to be called from inside an ``except`` block, so that
        ``traceback.format_exc()`` reports the active exception.
        """
        error_msg = traceback.format_exc()
        self.logger.error(f"Training failed: {str(error)}\n{error_msg}")
# Usage example
if __name__ == "__main__":
    config = {
        'base_path': '/data/power_forecast',
        'capacities': {
            '1001': 2.5,
            '1002': 3.0,
            # ... other station configurations
        },
        'gpu_assignment': [0, 1, 2, 3]  # list of available GPU ids
    }
    # NOTE(review): TrainingOrchestrator is neither defined nor imported in
    # this file, so running this module as a script raises NameError here.
    # TODO: import it from where it is defined (or define it) before use.
    # NOTE(review): station '1003' has no entry in config['capacities'];
    # confirm how a missing capacity is meant to be handled.
    orchestrator = TrainingOrchestrator(
        station_ids=['1001', '1002', '1003'],  # in production, hundreds of IDs are generated
        config=config,
        max_workers=4  # set according to the number of GPUs
    )
    orchestrator.execute()
|