|
@@ -1,144 +0,0 @@
|
|
|
-#!/usr/bin/env python
|
|
|
-# -*- coding:utf-8 -*-
|
|
|
-# @FileName :tf_model_train.py
|
|
|
-# @Time :2025/4/29 14:05
|
|
|
-# @Author :David
|
|
|
-# @Company: shenyang JY
|
|
|
-
|
|
|
-import logging
|
|
|
-import os, json
|
|
|
-import time, argparse
|
|
|
-import traceback
|
|
|
-import pandas as pd
|
|
|
-from typing import Dict, Any
|
|
|
-from app.common.tf_lstm import TSHandler
|
|
|
-from app.common.dbmg import MongoUtils
|
|
|
-from app.common.data_handler_region import DataHandlerRegion, write_number_to_file
|
|
|
-from app.common.config import logger, parser
|
|
|
-
|
|
|
-class RegionTrainer:
|
|
|
- """模型训练器封装类"""
|
|
|
-
|
|
|
- def __init__(self,
|
|
|
- train_data: pd.DataFrame,
|
|
|
- capacity: float,
|
|
|
- config: Dict[str, Any] = None
|
|
|
- ):
|
|
|
- self.config = config
|
|
|
- self.logger = logger
|
|
|
- self.train_data = train_data
|
|
|
- self.capacity = capacity
|
|
|
- self.gpu_id = config.get('gpu_assignment')
|
|
|
- self._setup_resources()
|
|
|
-
|
|
|
- # 初始化组件
|
|
|
- self.input_file = config.get("input_file")
|
|
|
- self.opt = argparse.Namespace(**config)
|
|
|
- self.dh = DataHandlerRegion(logger, self.opt)
|
|
|
- self.ts = TSHandler(logger, self.opt)
|
|
|
- self.mgUtils = MongoUtils(logger)
|
|
|
-
|
|
|
- def _setup_resources(self):
|
|
|
- """GPU资源分配"""
|
|
|
- if self.gpu_id is not None:
|
|
|
- os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)
|
|
|
- self.logger.info(f"GPU {self.gpu_id} allocated")
|
|
|
-
|
|
|
-
|
|
|
- def train(self):
|
|
|
- """执行训练流程"""
|
|
|
- # 获取程序开始时间
|
|
|
- start_time = time.time()
|
|
|
- success = 0
|
|
|
- print("aaa")
|
|
|
- farm_id = self.input_file.split('/')[-2]
|
|
|
- output_file = self.input_file.replace('IN', 'OUT')
|
|
|
- status_file = 'STATUS.TXT'
|
|
|
- try:
|
|
|
- # ------------ 获取数据,预处理训练数据 ------------
|
|
|
- self.dh.opt.cap = self.capacity
|
|
|
- train_x, valid_x, train_y, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap = self.dh.train_data_handler(self.train_data)
|
|
|
- self.ts.opt.Model['input_size'] = train_x.shape[2]
|
|
|
- # ------------ 训练模型,保存模型 ------------
|
|
|
- # 1. 如果是加强训练模式,先加载预训练模型特征参数,再预处理训练数据
|
|
|
- # 2. 如果是普通模式,先预处理训练数据,再根据训练数据特征加载模型
|
|
|
- print("bbb")
|
|
|
- model = self.ts.train_init() if self.ts.opt.Model['add_train'] else self.ts.get_keras_model(self.ts.opt)
|
|
|
- if self.ts.opt.Model['add_train']:
|
|
|
- if model:
|
|
|
- feas = json.loads(self.ts.model_params).get('features', self.dh.opt.features)
|
|
|
- if set(feas).issubset(set(self.dh.opt.features)):
|
|
|
- self.dh.opt.features = list(feas)
|
|
|
- train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap = self.dh.train_data_handler(self.train_data)
|
|
|
- else:
|
|
|
- model = self.ts.get_keras_model(self.ts.opt)
|
|
|
- self.logger.info("训练数据特征,不满足,加强训练模型特征")
|
|
|
- else:
|
|
|
- model = self.ts.get_keras_model(self.ts.opt)
|
|
|
- print("ccc")
|
|
|
- # 执行训练
|
|
|
- trained_model = self.ts.training(model, [train_x, valid_x, train_y, valid_y])
|
|
|
- # 模型持久化
|
|
|
- success = 1
|
|
|
- print('ddd')
|
|
|
- # 更新算法状态:1. 启动成功
|
|
|
- write_number_to_file(os.path.join(output_file, status_file), 1, 1, 'rewrite')
|
|
|
- # ------------ 组装模型数据 ------------
|
|
|
- self.opt.Model['features'] = ','.join(self.dh.opt.features)
|
|
|
- self.config.update({
|
|
|
- 'params': json.dumps(self.config['Model']),
|
|
|
- 'descr': f'南网竞赛-{farm_id}',
|
|
|
- 'gen_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
|
|
|
- 'model_table': self.config['model_table'] + farm_id,
|
|
|
- 'scaler_table': self.config['scaler_table'] + farm_id
|
|
|
- })
|
|
|
- self.mgUtils.insert_trained_model_into_mongo(trained_model, self.config)
|
|
|
- self.mgUtils.insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, self.config)
|
|
|
- # 更新算法状态:正常结束
|
|
|
- print("eee")
|
|
|
- write_number_to_file(os.path.join(output_file, status_file), 2, 2)
|
|
|
- return True
|
|
|
- except Exception as e:
|
|
|
- self._handle_error(e)
|
|
|
- return False
|
|
|
-
|
|
|
- def _initialize_model(self):
|
|
|
- """模型初始化策略"""
|
|
|
- if self.ts.opt.Model['add_train']:
|
|
|
- pretrained = self.ts.train_init()
|
|
|
- return pretrained if self._check_feature_compatibility(pretrained) else self.ts.get_keras_model()
|
|
|
- return self.ts.get_keras_model()
|
|
|
-
|
|
|
- def _check_feature_compatibility(self, model) -> bool:
|
|
|
- """检查特征兼容性"""
|
|
|
- # 原始逻辑中的特征校验实现
|
|
|
- pass
|
|
|
-
|
|
|
-
|
|
|
- def _handle_error(self, error: Exception):
|
|
|
- """统一错误处理"""
|
|
|
- error_msg = traceback.format_exc()
|
|
|
- self.logger.error(f"Training failed: {str(error)}\n{error_msg}")
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-# 使用示例
|
|
|
-if __name__ == "__main__":
|
|
|
- config = {
|
|
|
- 'base_path': '/data/power_forecast',
|
|
|
- 'capacities': {
|
|
|
- '1001': 2.5,
|
|
|
- '1002': 3.0,
|
|
|
- # ... 其他场站配置
|
|
|
- },
|
|
|
- 'gpu_assignment': [0, 1, 2, 3] # 可用GPU列表
|
|
|
- }
|
|
|
-
|
|
|
- orchestrator = TrainingOrchestrator(
|
|
|
- station_ids=['1001', '1002', '1003'], # 实际场景下生成数百个ID
|
|
|
- config=config,
|
|
|
- max_workers=4 # 根据GPU数量设置
|
|
|
- )
|
|
|
- orchestrator.execute()
|