|
@@ -0,0 +1,136 @@
|
|
|
+#!/usr/bin/env python
|
|
|
+# -*- coding:utf-8 -*-
|
|
|
+# @FileName :sync_query.py
|
|
|
+# @Time :2025/3/5 12:55
|
|
|
+# @Author :David
|
|
|
+# @Company: shenyang JY
|
|
|
+
|
|
|
+
|
|
|
+from flask import jsonify
|
|
|
+import threading
|
|
|
+import uuid
|
|
|
+import time
|
|
|
+import traceback
|
|
|
+from collections import defaultdict
|
|
|
+
|
|
|
+# Global in-memory store of training progress (Redis is recommended for production)
|
|
|
+training_progress = defaultdict(dict)
|
|
|
+progress_lock = threading.Lock()
|
|
|
+
|
|
|
+
|
|
|
@app.route('/nn_bp_training', methods=['POST'])
def start_training():
    """Create a training task, start it in a background thread and return its ID.

    A fresh UUID identifies the task. Its progress record is registered in
    ``training_progress`` with status ``'pending'`` before the worker thread
    is started, so the record already exists when the worker first reports.

    Returns:
        A JSON response carrying ``success``, the new ``task_id`` and a
        confirmation ``message``.
    """
    new_task_id = str(uuid.uuid4())

    # Register the initial progress record while holding the shared lock.
    with progress_lock:
        training_progress[new_task_id] = dict(
            status='pending',
            progress=0,
            message='任务已创建',
            result=None,
            start_time=time.time(),
            end_time=None,
        )

    # Hand the actual training over to a daemon thread so the request
    # returns immediately.
    worker = threading.Thread(
        target=async_training_task,
        args=(new_task_id,),
        daemon=True,
    )
    worker.start()

    response_body = dict(
        success=1,
        task_id=new_task_id,
        message='训练任务已启动',
    )
    return jsonify(response_body)
|
|
|
+
|
|
|
+
|
|
|
# The URL rule must declare <task_id>: the view takes a task_id argument, and
# without the variable part Flask would invoke it with no arguments.
@app.route('/training_progress/<task_id>')
def get_progress(task_id):
    """Return the current progress record of a training task as JSON.

    Args:
        task_id: Task identifier taken from the URL path.

    Returns:
        A JSON response holding a snapshot of the task's progress record, or
        a placeholder record with status ``'not_found'`` when no task with
        this ID is registered.
    """
    with progress_lock:
        record = training_progress.get(task_id)
        # Copy while holding the lock so the worker thread cannot modify the
        # record while it is being serialized below.
        snapshot = dict(record) if record is not None else None

    if snapshot is None:
        snapshot = {
            'status': 'not_found',
            'progress': 0,
            'message': '任务不存在',
        }

    return jsonify(snapshot)
|
|
|
+
|
|
|
+
|
|
|
def async_training_task(task_id):
    """Run one training job in a worker thread and publish its progress.

    The job goes through data preparation, model training and persistence of
    the trained model and scalers, reporting each stage through
    ``update_progress``. Every exception is caught here and recorded on the
    task as a failure (progress ``-1``), so nothing propagates out of the
    thread.

    Args:
        task_id: Identifier of the task whose progress record is updated.
    """
    # Imported locally: the module-level imports of this file do not
    # provide json, which is needed to serialize the parameters below.
    import json

    time_format = '%Y-%m-%d %H:%M:%S'
    args = {}  # TODO: populate with the real training parameters
    start_time = time.time()

    def training_callback(epoch, total_epoch):
        """Map the epoch counter onto the 30-90 band of the progress scale."""
        # update_progress treats values < 0 as failure and >= 100 as
        # completion, so guard against a zero epoch count and keep the
        # ratio inside [0, 1].
        ratio = epoch / total_epoch if total_epoch else 0.0
        ratio = min(max(ratio, 0.0), 1.0)
        update_progress(task_id, 30 + 60 * ratio, f'训练第 {epoch}/{total_epoch} 轮')

    try:
        update_progress(task_id, 10, '数据准备中...')

        # ------------ Data preparation ------------
        train_data = get_data_from_mongo(args)
        (train_x, train_y, valid_x, valid_y,
         scaled_train_bytes, scaled_target_bytes) = dh.train_data_handler(
            train_data, bp_data=True)

        # ------------ Model training ------------
        update_progress(task_id, 30, '模型训练中...')
        # NOTE(review): this writes into configuration reached through the
        # shared `bp` object; concurrent tasks may overwrite each other's
        # value - confirm.
        bp.opt.Model['input_size'] = train_x.shape[1]

        # NOTE(review): assumes bp.training accepts a `callback` keyword and
        # calls it as callback(epoch, total_epoch) - confirm.
        bp_model = bp.training([train_x, train_y, valid_x, valid_y],
                               callback=training_callback)

        # ------------ Persist results ------------
        update_progress(task_id, 95, '保存模型中...')
        args['params'] = json.dumps(args)
        args['descr'] = '测试'
        args['gen_time'] = time.strftime(time_format, time.localtime())
        insert_trained_model_into_mongo(bp_model, args)
        insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, args)

        # Final result attached to the progress record.
        result = {
            'success': 1,
            'args': args,
            'start_time': time.strftime(time_format, time.localtime(start_time)),
            'end_time': time.strftime(time_format, time.localtime()),
        }
        update_progress(task_id, 100, '训练完成', result=result)

    except Exception:
        # Top-level boundary of the worker thread: store the traceback on
        # the task (flattened to a single line) instead of losing it.
        error_msg = traceback.format_exc().replace("\n", "\t")
        result = {
            'success': 0,
            'msg': error_msg,
            'start_time': time.strftime(time_format, time.localtime(start_time)),
            'end_time': time.strftime(time_format, time.localtime()),
        }
        update_progress(task_id, -1, '训练失败', result=result)
|
|
|
+
|
|
|
+
|
|
|
def update_progress(task_id, progress, message, result=None):
    """Record the latest progress of a training task in the shared store.

    The status is derived from ``progress``: a value of 100 or more marks
    the task ``'completed'``, a negative value marks it ``'failed'`` and
    anything else means ``'running'``. Reaching either terminal status also
    stamps ``end_time`` with the current time.

    Args:
        task_id: Key of the task inside ``training_progress``.
        progress: Numeric progress value; it also encodes the terminal
            states as described above.
        message: Description of the current step.
        result: Optional payload stored under ``'result'``. Only ``None``
            leaves the stored result untouched, so an explicitly supplied
            empty payload is recorded as well.
    """
    if progress >= 100:
        status = 'completed'
    elif progress < 0:
        status = 'failed'
    else:
        status = 'running'

    with progress_lock:
        entry = training_progress[task_id]
        entry['progress'] = progress
        entry['message'] = message
        entry['status'] = status

        if status != 'running':
            entry['end_time'] = time.time()

        # Compare against None rather than truthiness so that an empty
        # result passed on purpose is not silently dropped.
        if result is not None:
            entry['result'] = result
|