tf_model_train.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. #!/usr/bin/env python
  2. # -*- coding:utf-8 -*-
  3. # @FileName :tf_model_train.py
  4. # @Time :2025/4/29 14:05
  5. # @Author :David
  6. # @Company: shenyang JY
  7. import logging
  8. import os, json
  9. import time
  10. import traceback
  11. from pathlib import Path
  12. from copy import deepcopy
  13. import pandas as pd
  14. from typing import Dict, Any
  15. from app.common.tf_lstm import TSHandler
  16. from app.common.dbmg import MongoUtils
  17. from app.common.data_handler import DataHandler, write_number_to_file
  18. from app.model.config import myargparse
  19. args = myargparse(discription="算法批量修模", add_help=False)
  20. mgUtils = MongoUtils(logger)
  21. class ModelTrainer:
  22. """模型训练器封装类"""
  23. def __init__(self,
  24. input_file,
  25. args,
  26. train_data: pd.DataFrame,
  27. capacity: float,
  28. gpu_id: int = None,
  29. config: Dict[str, Any] = None
  30. ):
  31. self.train_data = train_data
  32. self.capacity = capacity
  33. self.gpu_id = gpu_id
  34. self.config = config or {}
  35. self._setup_resources()
  36. # 初始化组件
  37. self.logger = logging.getLogger(self.__class__.__name__)
  38. self.input_file = input_file
  39. self.args = args # 从原始配置导入
  40. self.dh = DataHandler(self.logger, self.args)
  41. self.ts = TSHandler(self.logger, self.args)
  42. self.mgUtils = MongoUtils(self.logger)
  43. def _setup_resources(self):
  44. """GPU资源分配"""
  45. if self.gpu_id is not None:
  46. os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)
  47. self.logger.info(f"GPU {self.gpu_id} allocated")
  48. def train(self):
  49. """执行训练流程"""
  50. # 获取程序开始时间
  51. start_time = time.time()
  52. success = 0
  53. farm_id = self.input_file.split('/')[-2]
  54. output_file = self.input_file.replace('IN', 'OUT')
  55. status_file = 'STATUS.TXT'
  56. try:
  57. # ------------ 获取数据,预处理训练数据 ------------
  58. self.dh.opt.cap = self.capacity
  59. train_x, valid_x, train_y, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap = self.dh.train_data_handler(self.train_data)
  60. self.ts.opt.Model['input_size'] = train_x.shape[2]
  61. # ------------ 训练模型,保存模型 ------------
  62. # 1. 如果是加强训练模式,先加载预训练模型特征参数,再预处理训练数据
  63. # 2. 如果是普通模式,先预处理训练数据,再根据训练数据特征加载模型
  64. model = self.ts.train_init() if self.ts.opt.Model['add_train'] else self.ts.get_keras_model(self.ts.opt)
  65. if self.ts.opt.Model['add_train']:
  66. if model:
  67. feas = json.loads(self.ts.model_params).get('features', self.dh.opt.features)
  68. if set(feas).issubset(set(self.dh.opt.features)):
  69. self.dh.opt.features = list(feas)
  70. train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap = self.dh.train_data_handler(train_data)
  71. else:
  72. model = self.ts.get_keras_model(self.ts.opt)
  73. self.logger.info("训练数据特征,不满足,加强训练模型特征")
  74. else:
  75. model = self.ts.get_keras_model(self.ts.opt)
  76. # 执行训练
  77. trained_model = self.ts.training(model, [train_x, valid_x, train_y, valid_y])
  78. # 模型持久化
  79. success = 1
  80. # 更新算法状态:1. 启动成功
  81. write_number_to_file(os.path.join(output_file, status_file), 1, 1, 'rewrite')
  82. # ------------ 组装模型数据 ------------
  83. self.opt.Model['features'] = ','.join(self.dh.opt.features)
  84. self.args.update({
  85. 'params': json.dumps(local_params),
  86. 'descr': f'南网竞赛-{farm_id}',
  87. 'gen_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
  88. 'model_table': local_params['model_table'] + farm_id,
  89. 'scaler_table': local_params['scaler_table'] + farm_id
  90. })
  91. mgUtils.insert_trained_model_into_mongo(trained_model, local_params)
  92. mgUtils.insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, local_params)
  93. # 更新算法状态:正常结束
  94. write_number_to_file(os.path.join(output_file, status_file), 2, 2)
  95. return True
  96. except Exception as e:
  97. self._handle_error(e)
  98. return False
  99. def _initialize_model(self):
  100. """模型初始化策略"""
  101. if self.ts.opt.Model['add_train']:
  102. pretrained = self.ts.train_init()
  103. return pretrained if self._check_feature_compatibility(pretrained) else self.ts.get_keras_model()
  104. return self.ts.get_keras_model()
  105. def _check_feature_compatibility(self, model) -> bool:
  106. """检查特征兼容性"""
  107. # 原始逻辑中的特征校验实现
  108. pass
  109. def _handle_error(self, error: Exception):
  110. """统一错误处理"""
  111. error_msg = traceback.format_exc()
  112. self.logger.error(f"Training failed: {str(error)}\n{error_msg}")
  113. # 使用示例
  114. if __name__ == "__main__":
  115. config = {
  116. 'base_path': '/data/power_forecast',
  117. 'capacities': {
  118. '1001': 2.5,
  119. '1002': 3.0,
  120. # ... 其他场站配置
  121. },
  122. 'gpu_assignment': [0, 1, 2, 3] # 可用GPU列表
  123. }
  124. orchestrator = TrainingOrchestrator(
  125. station_ids=['1001', '1002', '1003'], # 实际场景下生成数百个ID
  126. config=config,
  127. max_workers=4 # 根据GPU数量设置
  128. )
  129. orchestrator.execute()