async_query_task.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. #!/usr/bin/env python
  2. # -*- coding:utf-8 -*-
  3. # @FileName :sync_query.py
  4. # @Time :2025/3/5 12:55
  5. # @Author :David
  6. # @Company: shenyang JY
  7. from flask import jsonify
  8. import threading
  9. import uuid
  10. import time
  11. import traceback
  12. from collections import defaultdict
  13. # 全局存储训练进度(生产环境建议使用Redis)
  14. training_progress = defaultdict(dict)
  15. progress_lock = threading.Lock()
  16. @app.route('/nn_bp_training', methods=['POST'])
  17. def start_training():
  18. """启动训练任务接口"""
  19. task_id = str(uuid.uuid4())
  20. # 初始化任务进度
  21. with progress_lock:
  22. training_progress[task_id] = {
  23. 'status': 'pending',
  24. 'progress': 0,
  25. 'message': '任务已创建',
  26. 'result': None,
  27. 'start_time': time.time(),
  28. 'end_time': None
  29. }
  30. # 启动异步训练线程
  31. thread = threading.Thread(
  32. target=async_training_task,
  33. args=(task_id,),
  34. daemon=True
  35. )
  36. thread.start()
  37. return jsonify({
  38. 'success': 1,
  39. 'task_id': task_id,
  40. 'message': '训练任务已启动'
  41. })
  42. @app.route('/training_progress/')
  43. def get_progress(task_id):
  44. """查询训练进度接口"""
  45. with progress_lock:
  46. progress = training_progress.get(task_id, {
  47. 'status': 'not_found',
  48. 'progress': 0,
  49. 'message': '任务不存在'
  50. })
  51. return jsonify(progress)
  52. def async_training_task(task_id):
  53. """异步训练任务"""
  54. args = {} # 根据实际情况获取参数
  55. result = {}
  56. start_time = time.time()
  57. try:
  58. # 更新任务状态
  59. update_progress(task_id, 10, '数据准备中...')
  60. # ------------ 数据准备 ------------
  61. train_data = get_data_from_mongo(args)
  62. train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes = dh.train_data_handler(
  63. train_data, bp_data=True)
  64. # ------------ 模型训练 ------------
  65. update_progress(task_id, 30, '模型训练中...')
  66. bp.opt.Model['input_size'] = train_x.shape[1]
  67. # 包装训练函数以跟踪进度
  68. def training_callback(epoch, total_epoch):
  69. progress = 30 + 60 * (epoch / total_epoch)
  70. update_progress(task_id, progress, f'训练第 {epoch}/{total_epoch} 轮')
  71. bp_model = bp.training([train_x, train_y, valid_x, valid_y],
  72. callback=training_callback)
  73. # ------------ 保存结果 ------------
  74. update_progress(task_id, 95, '保存模型中...')
  75. args['params'] = json.dumps(args)
  76. args['descr'] = '测试'
  77. args['gen_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
  78. insert_trained_model_into_mongo(bp_model, args)
  79. insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, args)
  80. # 最终结果
  81. result.update({
  82. 'success': 1,
  83. 'args': args,
  84. 'start_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time)),
  85. 'end_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
  86. })
  87. update_progress(task_id, 100, '训练完成', result=result)
  88. except Exception as e:
  89. error_msg = traceback.format_exc().replace("\n", "\t")
  90. result = {
  91. 'success': 0,
  92. 'msg': error_msg,
  93. 'start_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time)),
  94. 'end_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
  95. }
  96. update_progress(task_id, -1, '训练失败', result=result)
  97. def update_progress(task_id, progress, message, result=None):
  98. """更新进度工具函数"""
  99. with progress_lock:
  100. training_progress[task_id]['progress'] = progress
  101. training_progress[task_id]['message'] = message
  102. training_progress[task_id]['status'] = 'running'
  103. if progress >= 100:
  104. training_progress[task_id]['status'] = 'completed'
  105. training_progress[task_id]['end_time'] = time.time()
  106. elif progress < 0:
  107. training_progress[task_id]['status'] = 'failed'
  108. training_progress[task_id]['end_time'] = time.time()
  109. if result:
  110. training_progress[task_id]['result'] = result