async_query_task.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. #!/usr/bin/env python
  2. # -*- coding:utf-8 -*-
# @FileName :async_query_task.py
  4. # @Time :2025/3/5 12:55
  5. # @Author :David
  6. # @Company: shenyang JY
  7. from flask import jsonify
  8. import threading
  9. import uuid
  10. import time
  11. import traceback
  12. from collections import defaultdict
# Process-wide, in-memory store of training progress, keyed by task id.
# (For production use, Redis is recommended instead.)
training_progress = defaultdict(dict)
# Guards every read and write of training_progress, which is shared between
# the request handlers and the background training threads.
progress_lock = threading.Lock()
  16. @app.route('/nn_bp_training', methods=['POST'])
  17. def start_training():
  18. """启动训练任务接口"""
  19. task_id = str(uuid.uuid4())
  20. # 初始化任务进度
  21. with progress_lock:
  22. training_progress[task_id] = {
  23. 'status': 'pending',
  24. 'progress': 0,
  25. 'message': '任务已创建',
  26. 'result': None,
  27. 'start_time': time.time(),
  28. 'end_time': None
  29. }
  30. # 启动异步训练线程
  31. thread = threading.Thread(
  32. target=async_training_task,
  33. args=(task_id,),
  34. daemon=True
  35. )
  36. thread.start()
  37. return jsonify({
  38. 'success': 1,
  39. 'task_id': task_id,
  40. 'message': '训练任务已启动'
  41. })
  42. @app.route('/training_progress/')
  43. def get_progress(task_id):
  44. """查询训练进度接口"""
  45. with progress_lock:
  46. progress = training_progress.get(task_id, {
  47. 'status': 'not_found',
  48. 'progress': 0,
  49. 'message': '任务不存在'
  50. })
  51. return jsonify(progress)
  52. @app.route('/training_progress/')
  53. def get_progress(task_id):
  54. """查询训练进度接口"""
  55. with progress_lock:
  56. progress = training_progress.get(task_id, {
  57. 'status': 'not_found',
  58. 'progress': 0,
  59. 'message': '任务不存在'
  60. })
  61. return jsonify(progress)
  62. def async_training_task(task_id):
  63. """异步训练任务"""
  64. args = {} # 根据实际情况获取参数
  65. result = {}
  66. start_time = time.time()
  67. try:
  68. # 更新任务状态
  69. update_progress(task_id, 10, '数据准备中...')
  70. # ------------ 数据准备 ------------
  71. train_data = get_data_from_mongo(args)
  72. train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes = dh.train_data_handler(
  73. train_data, bp_data=True)
  74. # ------------ 模型训练 ------------
  75. update_progress(task_id, 30, '模型训练中...')
  76. bp.opt.Model['input_size'] = train_x.shape[1]
  77. # 包装训练函数以跟踪进度
  78. def training_callback(epoch, total_epoch):
  79. progress = 30 + 60 * (epoch / total_epoch)
  80. update_progress(task_id, progress, f'训练第 {epoch}/{total_epoch} 轮')
  81. bp_model = bp.training([train_x, train_y, valid_x, valid_y],
  82. callback=training_callback)
  83. # ------------ 保存结果 ------------
  84. update_progress(task_id, 95, '保存模型中...')
  85. args['params'] = json.dumps(args)
  86. args['descr'] = '测试'
  87. args['gen_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
  88. insert_trained_model_into_mongo(bp_model, args)
  89. insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, args)
  90. # 最终结果
  91. result.update({
  92. 'success': 1,
  93. 'args': args,
  94. 'start_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time)),
  95. 'end_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
  96. })
  97. update_progress(task_id, 100, '训练完成', result=result)
  98. except Exception as e:
  99. error_msg = traceback.format_exc().replace("\n", "\t")
  100. result = {
  101. 'success': 0,
  102. 'msg': error_msg,
  103. 'start_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time)),
  104. 'end_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
  105. }
  106. update_progress(task_id, -1, '训练失败', result=result)
  107. def update_progress(task_id, progress, message, result=None):
  108. """更新进度工具函数"""
  109. with progress_lock:
  110. training_progress[task_id]['progress'] = progress
  111. training_progress[task_id]['message'] = message
  112. training_progress[task_id]['status'] = 'running'
  113. if progress >= 100:
  114. training_progress[task_id]['status'] = 'completed'
  115. training_progress[task_id]['end_time'] = time.time()
  116. elif progress < 0:
  117. training_progress[task_id]['status'] = 'failed'
  118. training_progress[task_id]['end_time'] = time.time()
  119. if result:
  120. training_progress[task_id]['result'] = result