#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :sync_query.py
# @Time      :2025/3/5 12:55
# @Author    :David
# @Company: shenyang JY
"""Flask endpoints that start a background BP-network training task and
let clients poll its progress.

Progress records live in an in-process dict guarded by a lock. They are
therefore only visible inside the process that started the task, and
nothing in this module ever removes finished entries.
"""
import json
import threading
import time
import traceback
import uuid
from collections import defaultdict

from flask import Flask, jsonify

# The original snippet decorated views with ``@app.route`` without ever
# defining ``app`` (NameError at import time). It is created here.
app = Flask(__name__)

# Global training-progress store, keyed by task id.
# (For production, prefer a shared store such as Redis.)
training_progress = defaultdict(dict)
progress_lock = threading.Lock()

_TIME_FMT = '%Y-%m-%d %H:%M:%S'


def _format_ts(ts=None):
    """Format epoch seconds *ts* (default: now) as local 'YYYY-mm-dd HH:MM:SS'."""
    return time.strftime(_TIME_FMT, time.localtime(ts))


@app.route('/nn_bp_training', methods=['POST'])
def start_training():
    """Create a training task, run it on a daemon thread and return its id.

    Returns:
        JSON ``{'success': 1, 'task_id': <uuid4 string>, 'message': ...}``.
        Poll ``/training_progress/<task_id>`` with the returned id.
    """
    task_id = str(uuid.uuid4())

    # Register the task before the worker starts, so an immediate poll finds it.
    with progress_lock:
        training_progress[task_id] = {
            'status': 'pending',
            'progress': 0,
            'message': '任务已创建',
            'result': None,
            'start_time': time.time(),
            'end_time': None,
        }

    # Run training asynchronously. daemon=True means the thread will not
    # keep the process alive on shutdown.
    thread = threading.Thread(
        target=async_training_task,
        args=(task_id,),
        daemon=True,
    )
    thread.start()

    return jsonify({
        'success': 1,
        'task_id': task_id,
        'message': '训练任务已启动',
    })


@app.route('/training_progress/<task_id>')
def get_progress(task_id):
    """Return the progress record for *task_id* as JSON.

    Unknown ids yield ``{'status': 'not_found', 'progress': 0, ...}``.
    """
    with progress_lock:
        record = training_progress.get(task_id)  # .get does not insert into the defaultdict
        # Snapshot while holding the lock: the worker thread keeps writing
        # to the live dict, and jsonify runs after the lock is released.
        if record is not None:
            progress = dict(record)
        else:
            progress = {
                'status': 'not_found',
                'progress': 0,
                'message': '任务不存在',
            }
    return jsonify(progress)


def async_training_task(task_id):
    """Prepare data, train the BP model and persist it, reporting progress.

    Runs on the background thread started by ``start_training``.
    Progress milestones:
        10 while preparing data,
        30-90 across training epochs,
        95 while saving,
        100 on success, or -1 on failure.
    Any exception is caught and its traceback is stored in the task's
    ``result['msg']``.
    """
    # TODO(review): get_data_from_mongo, dh, bp,
    # insert_trained_model_into_mongo and insert_scaler_model_into_mongo are
    # not imported in this file. They must come from the project; until they
    # do, every task ends in the failure branch with a NameError traceback.
    args = {}  # TODO: populate with the real training parameters
    start_time = time.time()

    try:
        update_progress(task_id, 10, '数据准备中...')

        # ------------ data preparation ------------
        train_data = get_data_from_mongo(args)
        (train_x, train_y, valid_x, valid_y,
         scaled_train_bytes, scaled_target_bytes) = dh.train_data_handler(
            train_data, bp_data=True)

        # ------------ model training ------------
        update_progress(task_id, 30, '模型训练中...')
        bp.opt.Model['input_size'] = train_x.shape[1]

        def training_callback(epoch, total_epoch):
            # Map epoch progress onto the 30..90 band.
            # Guard against total_epoch == 0.
            fraction = epoch / total_epoch if total_epoch else 1.0
            progress = 30 + 60 * fraction
            update_progress(task_id, progress, f'训练第 {epoch}/{total_epoch} 轮')

        # NOTE(review): assumes bp.training accepts a ``callback`` keyword
        # invoked as callback(epoch, total_epoch). Confirm against bp.
        bp_model = bp.training([train_x, train_y, valid_x, valid_y],
                               callback=training_callback)

        # ------------ persist results ------------
        update_progress(task_id, 95, '保存模型中...')
        args['params'] = json.dumps(args)
        args['descr'] = '测试'
        args['gen_time'] = _format_ts()
        insert_trained_model_into_mongo(bp_model, args)
        insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, args)

        result = {
            'success': 1,
            'args': args,
            'start_time': _format_ts(start_time),
            'end_time': _format_ts(),
        }
        update_progress(task_id, 100, '训练完成', result=result)
    except Exception:
        # Top-level boundary of the worker thread. Nobody joins this thread,
        # so record the failure where pollers can see it instead of letting
        # the exception vanish.
        error_msg = traceback.format_exc().replace("\n", "\t")
        result = {
            'success': 0,
            'msg': error_msg,
            'start_time': _format_ts(start_time),
            'end_time': _format_ts(),
        }
        update_progress(task_id, -1, '训练失败', result=result)


def update_progress(task_id, progress, message, result=None):
    """Record a progress update for *task_id* while holding ``progress_lock``.

    Args:
        task_id: id returned by ``start_training``.
        progress: percentage value. ``>= 100`` marks the task 'completed'
            and ``< 0`` marks it 'failed'; both set ``end_time``. Any other
            value marks it 'running'.
        message: human-readable status text.
        result: optional payload stored under ``'result'`` when not None.
    """
    with progress_lock:
        task = training_progress[task_id]
        task['progress'] = progress
        task['message'] = message
        if progress >= 100:
            task['status'] = 'completed'
            task['end_time'] = time.time()
        elif progress < 0:
            task['status'] = 'failed'
            task['end_time'] = time.time()
        else:
            task['status'] = 'running'
        if result is not None:
            task['result'] = result