123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148 |
- #!/usr/bin/env python
- # -*- coding:utf-8 -*-
- # @FileName :sync_query.py
- # @Time :2025/3/5 12:55
- # @Author :David
- # @Company: shenyang JY
import json
import threading
import time
import traceback
import uuid
from collections import defaultdict

from flask import jsonify
# Serializes every read/write of ``training_progress``: the Flask request
# threads and the background training threads all touch the same dict.
progress_lock = threading.Lock()
# In-memory registry of task records keyed by task_id. It is process-local
# and nothing in this module ever removes entries, so for production use a
# shared store with expiry (e.g. Redis) instead.
training_progress = defaultdict(dict)
@app.route('/nn_bp_training', methods=['POST'])
def start_training():
    """Register a new training task and launch it on a daemon thread.

    The task record starts in the ``'pending'`` state with 0 % progress.
    The JSON reply carries the generated ``task_id``, which clients use to
    poll the progress endpoint.
    """
    new_id = str(uuid.uuid4())
    initial_record = dict(
        status='pending',
        progress=0,
        message='任务已创建',
        result=None,
        start_time=time.time(),
        end_time=None,
    )
    with progress_lock:
        training_progress[new_id] = initial_record

    # Daemon thread: it will not keep the process alive on shutdown.
    worker = threading.Thread(target=async_training_task, args=(new_id,), daemon=True)
    worker.start()

    reply = {'success': 1, 'task_id': new_id, 'message': '训练任务已启动'}
    return jsonify(reply)
@app.route('/training_progress/<task_id>')
def get_progress(task_id):
    """Return the progress record of one training task as JSON.

    The ``<task_id>`` URL converter is what supplies the ``task_id`` argument;
    without it Flask would call this view with no arguments. The view is
    registered exactly once, because Flask refuses to map two different view
    functions to the same endpoint.

    Unknown ids get a placeholder record whose status is ``'not_found'``.
    """
    with progress_lock:
        # ``.get`` rather than indexing, so a lookup never inserts a new entry
        # into the defaultdict. The record is copied while the lock is held,
        # so jsonify below serializes a consistent snapshot instead of a dict
        # that a worker thread may be updating.
        record = training_progress.get(task_id)
        if record is not None:
            progress = dict(record)
        else:
            progress = {
                'status': 'not_found',
                'progress': 0,
                'message': '任务不存在',
            }
    return jsonify(progress)
def _format_timestamp(ts=None):
    """Format epoch seconds ``ts`` (default: now) as local 'YYYY-MM-DD HH:MM:SS'."""
    # time.localtime(None) uses the current time.
    return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(ts))


def async_training_task(task_id, args=None):
    """Run one BP-network training job and record its progress.

    Intended to run on a background thread. Each stage reports through
    ``update_progress``:
    - 10: data preparation
    - 30-90: training
    - 95: saving
    - 100: done

    On success the final record's result carries the saved ``args`` and the
    start/end times. On any exception it carries the formatted traceback, and
    progress is set to -1.

    Args:
        task_id: key of the task record in ``training_progress``.
        args: optional training parameters. They are passed to
            ``get_data_from_mongo`` and stored with the model. The dict is
            copied before use, so the caller's dict is never mutated.
            Defaults to an empty dict.
    """
    # Copy: the metadata keys added before saving must not leak into the caller's dict.
    args = dict(args) if args is not None else {}
    start_time = time.time()
    try:
        update_progress(task_id, 10, '数据准备中...')

        # ------------ Data preparation ------------
        train_data = get_data_from_mongo(args)
        (train_x, train_y, valid_x, valid_y,
         scaled_train_bytes, scaled_target_bytes) = dh.train_data_handler(train_data, bp_data=True)

        # ------------ Model training ------------
        update_progress(task_id, 30, '模型训练中...')
        # The input width is taken from axis 1 of train_x.
        # NOTE(review): this mutates shared module state in ``bp``; concurrent
        # tasks may race here — confirm whether that is acceptable.
        bp.opt.Model['input_size'] = train_x.shape[1]

        def training_callback(epoch, total_epoch):
            # Maps training progress onto the 30-90 % band. The guard stops a
            # total_epoch of 0 from raising ZeroDivisionError inside the trainer.
            # Presumably bp.training calls this once per epoch — confirm.
            fraction = epoch / total_epoch if total_epoch else 1.0
            update_progress(task_id, 30 + 60 * fraction, f'训练第 {epoch}/{total_epoch} 轮')

        bp_model = bp.training([train_x, train_y, valid_x, valid_y], callback=training_callback)

        # ------------ Persist results ------------
        update_progress(task_id, 95, '保存模型中...')
        # Snapshot of the caller-supplied parameters, taken before the metadata keys are added.
        args['params'] = json.dumps(args)
        args['descr'] = '测试'
        args['gen_time'] = _format_timestamp()
        insert_trained_model_into_mongo(bp_model, args)
        insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, args)

        result = {
            'success': 1,
            'args': args,
            'start_time': _format_timestamp(start_time),
            'end_time': _format_timestamp(),
        }
        update_progress(task_id, 100, '训练完成', result=result)
    except Exception:
        # Thread boundary: nothing above this frame can catch the error, so the
        # failure (with its traceback, flattened to one line) is recorded on the task.
        result = {
            'success': 0,
            'msg': traceback.format_exc().replace("\n", "\t"),
            'start_time': _format_timestamp(start_time),
            'end_time': _format_timestamp(),
        }
        update_progress(task_id, -1, '训练失败', result=result)
def update_progress(task_id, progress, message, result=None):
    """Update the progress record of ``task_id`` while holding ``progress_lock``.

    The status is derived from ``progress``:
    - ``>= 100`` marks the task 'completed'.
    - A negative value marks it 'failed'.
    - Anything else marks it 'running'.

    'completed' and 'failed' also stamp ``end_time``.

    Args:
        task_id: key into ``training_progress``. That store is a
            ``defaultdict(dict)``, so an unknown id creates a new, partial record.
        progress: percentage complete, or a negative value to signal failure.
        message: human-readable status text.
        result: optional final payload. It is stored whenever it is not None,
            so an empty dict is recorded rather than silently dropped.
    """
    if progress >= 100:
        status = 'completed'
    elif progress < 0:
        status = 'failed'
    else:
        status = 'running'

    with progress_lock:
        record = training_progress[task_id]
        record['progress'] = progress
        record['message'] = message
        record['status'] = status
        if status != 'running':
            record['end_time'] = time.time()
        if result is not None:
            record['result'] = result
|