tf_model_train.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. #!/usr/bin/env python
  2. # -*- coding:utf-8 -*-
  3. # @FileName :tf_model_train.py
  4. # @Time :2025/4/29 14:05
  5. # @Author :David
  6. # @Company: shenyang JY
  7. import logging
  8. import os, json
  9. import time, argparse
  10. import traceback
  11. import pandas as pd
  12. from typing import Dict, Any
  13. from app.common.tf_lstm import TSHandler
  14. from app.common.dbmg import MongoUtils
  15. from app.common.data_handler import DataHandler, write_number_to_file
  16. from app.common.config import logger, parser
  17. class ModelTrainer:
  18. """模型训练器封装类"""
  19. def __init__(self,
  20. train_data: pd.DataFrame,
  21. capacity: float,
  22. config: Dict[str, Any] = None,
  23. ):
  24. self.config = config
  25. self.logger = logger
  26. self.train_data = train_data
  27. self.capacity = capacity
  28. self.gpu_id = config.get('gpu_assignment')
  29. self._setup_resources()
  30. # 初始化组件
  31. self.input_file = config.get("input_file")
  32. self.opt = argparse.Namespace(**config)
  33. self.dh = DataHandler(logger, self.opt)
  34. self.ts = TSHandler(logger, self.opt)
  35. self.mgUtils = MongoUtils(logger)
  36. def _setup_resources(self):
  37. """GPU资源分配"""
  38. if self.gpu_id is not None:
  39. os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)
  40. self.logger.info(f"GPU {self.gpu_id} allocated")
  41. def train(self, pre_area=False):
  42. """执行训练流程"""
  43. # 获取程序开始时间
  44. start_time = time.time()
  45. success = 0
  46. print("aaa")
  47. # 预测编号:场站级,场站id,区域级,区域id
  48. pre_id = self.config['area_id'] if pre_area else self.config['station_id']
  49. pre_type = 'a' if pre_area else 's'
  50. output_file = self.input_file.replace('IN', 'OUT')
  51. status_file = 'STATUS.TXT'
  52. try:
  53. # ------------ 获取数据,预处理训练数据 ------------
  54. self.dh.opt.cap = self.capacity
  55. train_x, valid_x, train_y, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap = self.dh.train_data_handler(self.train_data)
  56. self.ts.opt.Model['input_size'] = train_x.shape[2]
  57. # ------------ 训练模型,保存模型 ------------
  58. # 1. 如果是加强训练模式,先加载预训练模型特征参数,再预处理训练数据
  59. # 2. 如果是普通模式,先预处理训练数据,再根据训练数据特征加载模型
  60. print("bbb")
  61. model = self.ts.train_init() if self.ts.opt.Model['add_train'] else self.ts.get_keras_model(self.ts.opt)
  62. if self.ts.opt.Model['add_train']:
  63. if model:
  64. feas = json.loads(self.ts.model_params).get('features', self.dh.opt.features)
  65. if set(feas).issubset(set(self.dh.opt.features)):
  66. self.dh.opt.features = list(feas)
  67. train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes, scaled_cap = self.dh.train_data_handler(self.train_data)
  68. else:
  69. model = self.ts.get_keras_model(self.ts.opt)
  70. self.logger.info("训练数据特征,不满足,加强训练模型特征")
  71. else:
  72. model = self.ts.get_keras_model(self.ts.opt)
  73. print("ccc")
  74. # 执行训练
  75. trained_model = self.ts.training(model, [train_x, valid_x, train_y, valid_y])
  76. # 模型持久化
  77. success = 1
  78. print('ddd')
  79. # 更新算法状态:1. 启动成功
  80. write_number_to_file(os.path.join(output_file, status_file), 1, 1, 'rewrite')
  81. # ------------ 组装模型数据 ------------
  82. self.opt.Model['features'] = ','.join(self.dh.opt.features)
  83. self.config.update({
  84. 'params': json.dumps(self.config['Model']),
  85. 'descr': f'南网竞赛-{pre_id}',
  86. 'gen_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
  87. 'model_table': self.config['model_table'] + f'_{pre_type}_' + pre_id,
  88. 'scaler_table': self.config['scaler_table'] + f'_{pre_type}_'+ pre_id
  89. })
  90. self.mgUtils.insert_trained_model_into_mongo(trained_model, self.config)
  91. self.mgUtils.insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, self.config)
  92. # 更新算法状态:正常结束
  93. print("eee")
  94. write_number_to_file(os.path.join(output_file, status_file), 2, 2)
  95. return True
  96. except Exception as e:
  97. self._handle_error(e)
  98. return False
  99. def _initialize_model(self):
  100. """模型初始化策略"""
  101. if self.ts.opt.Model['add_train']:
  102. pretrained = self.ts.train_init()
  103. return pretrained if self._check_feature_compatibility(pretrained) else self.ts.get_keras_model()
  104. return self.ts.get_keras_model()
  105. def _check_feature_compatibility(self, model) -> bool:
  106. """检查特征兼容性"""
  107. # 原始逻辑中的特征校验实现
  108. pass
  109. def _handle_error(self, error: Exception):
  110. """统一错误处理"""
  111. error_msg = traceback.format_exc()
  112. self.logger.error(f"Training failed: {str(error)}\n{error_msg}")
  113. # 使用示例
  114. if __name__ == "__main__":
  115. config = {
  116. 'base_path': '/data/power_forecast',
  117. 'capacities': {
  118. '1001': 2.5,
  119. '1002': 3.0,
  120. # ... 其他场站配置
  121. },
  122. 'gpu_assignment': [0, 1, 2, 3] # 可用GPU列表
  123. }
  124. orchestrator = TrainingOrchestrator(
  125. station_ids=['1001', '1002', '1003'], # 实际场景下生成数百个ID
  126. config=config,
  127. max_workers=4 # 根据GPU数量设置
  128. )
  129. orchestrator.execute()