03051302

David, 2 months ago
parent commit 445e01368e

+ 16 - 6
common/database_dml_koi.py

@@ -370,12 +370,22 @@ def get_h5_model_from_mongo( args: Dict[str, Any], custom_objects: Optional[Dict
         # ------------------------- Memory-optimized loading -------------------------
         if model_doc:
             model_data = model_doc['model_data']  # get the model's binary data
-            # Load the binary data into a BytesIO buffer
-            model_buffer = BytesIO(model_data)
-            # Load the model from the buffer
-            # Use h5py and BytesIO to load the model from memory
-            with h5py.File(model_buffer, 'r') as f:
-                model = tf.keras.models.load_model(f, custom_objects=custom_objects)
+            # # Load the binary data into a BytesIO buffer
+            # model_buffer = BytesIO(model_data)
+            # # Make sure the pointer is at the start
+            # model_buffer.seek(0)
+            # # Load the model from the buffer
+            # # Use h5py and BytesIO to load the model from memory
+            # with h5py.File(model_buffer, 'r', driver='fileobj') as f:
+            #     model = tf.keras.models.load_model(f, custom_objects=custom_objects)
+            # Create a temporary file
+            with tempfile.NamedTemporaryFile(suffix=".h5", delete=False) as tmp_file:
+                tmp_file.write(model_data)
+                tmp_file_path = tmp_file.name  # get the temporary file path
+
+            # Load the model from the temporary file
+            model = tf.keras.models.load_model(tmp_file_path, custom_objects=custom_objects)
+
             print(f"Model {args['model_name']} loaded successfully from MongoDB!")
             return model
     except tf.errors.NotFoundError as e:
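Note on the new temporary-file path: the NamedTemporaryFile is created with delete=False and is never removed after load_model returns, so every load leaves a stray .h5 file in the temp directory. A minimal cleanup sketch, assuming the same tempfile-based approach (the helper name load_keras_model_from_bytes is hypothetical, not part of this module):

import os
import tempfile

import tensorflow as tf


def load_keras_model_from_bytes(model_data, custom_objects=None):
    """Write H5 bytes to a temp file, load the model, then delete the file."""
    # delete=False so the file survives the `with` block and can be reopened by path
    with tempfile.NamedTemporaryFile(suffix=".h5", delete=False) as tmp_file:
        tmp_file.write(model_data)
        tmp_file_path = tmp_file.name
    try:
        return tf.keras.models.load_model(tmp_file_path, custom_objects=custom_objects)
    finally:
        # Remove the temp file even if loading fails
        os.remove(tmp_file_path)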

+ 136 - 0
models_processing/model_koi/async_query_task.py

@@ -0,0 +1,136 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+# @FileName  :async_query_task.py
+# @Time      :2025/3/5 12:55
+# @Author    :David
+# @Company: shenyang JY
+
+import json
+from flask import jsonify
+import threading
+import uuid
+import time
+import traceback
+from collections import defaultdict
+
+# Global store for training progress (Redis is recommended in production)
+training_progress = defaultdict(dict)
+progress_lock = threading.Lock()
+
+
+@app.route('/nn_bp_training', methods=['POST'])
+def start_training():
+    """启动训练任务接口"""
+    task_id = str(uuid.uuid4())
+
+    # 初始化任务进度
+    with progress_lock:
+        training_progress[task_id] = {
+            'status': 'pending',
+            'progress': 0,
+            'message': '任务已创建',
+            'result': None,
+            'start_time': time.time(),
+            'end_time': None
+        }
+
+    # Start the asynchronous training thread
+    thread = threading.Thread(
+        target=async_training_task,
+        args=(task_id,),
+        daemon=True
+    )
+    thread.start()
+
+    return jsonify({
+        'success': 1,
+        'task_id': task_id,
+        'message': 'Training task started'
+    })
+
+
+@app.route('/training_progress/<task_id>')
+def get_progress(task_id):
+    """Endpoint that queries training progress"""
+    with progress_lock:
+        progress = training_progress.get(task_id, {
+            'status': 'not_found',
+            'progress': 0,
+            'message': 'Task not found'
+        })
+
+    return jsonify(progress)
+
+
+def async_training_task(task_id):
+    """异步训练任务"""
+    args = {}  # 根据实际情况获取参数
+    result = {}
+    start_time = time.time()
+
+    try:
+        # Update the task status
+        update_progress(task_id, 10, 'Preparing data...')
+
+        # ------------ Data preparation ------------
+        train_data = get_data_from_mongo(args)
+        train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes = dh.train_data_handler(
+            train_data, bp_data=True)
+
+        # ------------ Model training ------------
+        update_progress(task_id, 30, 'Training model...')
+        bp.opt.Model['input_size'] = train_x.shape[1]
+
+        # Wrap the training routine with a callback that tracks progress
+        def training_callback(epoch, total_epoch):
+            progress = 30 + 60 * (epoch / total_epoch)
+            update_progress(task_id, progress, f'Training epoch {epoch}/{total_epoch}')
+
+        bp_model = bp.training([train_x, train_y, valid_x, valid_y],
+                               callback=training_callback)
+
+        # ------------ Save results ------------
+        update_progress(task_id, 95, 'Saving model...')
+        args['params'] = json.dumps(args)
+        args['descr'] = 'test'
+        args['gen_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
+        insert_trained_model_into_mongo(bp_model, args)
+        insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, args)
+
+        # Final result
+        result.update({
+            'success': 1,
+            'args': args,
+            'start_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time)),
+            'end_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
+        })
+
+        update_progress(task_id, 100, 'Training complete', result=result)
+
+    except Exception as e:
+        error_msg = traceback.format_exc().replace("\n", "\t")
+        result = {
+            'success': 0,
+            'msg': error_msg,
+            'start_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time)),
+            'end_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
+        }
+        update_progress(task_id, -1, 'Training failed', result=result)
+
+
+def update_progress(task_id, progress, message, result=None):
+    """更新进度工具函数"""
+    with progress_lock:
+        training_progress[task_id]['progress'] = progress
+        training_progress[task_id]['message'] = message
+        training_progress[task_id]['status'] = 'running'
+
+        if progress >= 100:
+            training_progress[task_id]['status'] = 'completed'
+            training_progress[task_id]['end_time'] = time.time()
+        elif progress < 0:
+            training_progress[task_id]['status'] = 'failed'
+            training_progress[task_id]['end_time'] = time.time()
+
+        if result:
+            training_progress[task_id]['result'] = result
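This module references names defined elsewhere in the project (app, dh, bp, get_data_from_mongo, insert_trained_model_into_mongo, insert_scaler_model_into_mongo); only the imports shown above are declared here. How a caller would use the two new endpoints, as a rough sketch (the base URL, port, and polling interval are assumptions, not part of this commit): POST to /nn_bp_training to obtain a task_id, then poll /training_progress/<task_id> until the status becomes completed or failed.

import time

import requests

BASE_URL = "http://127.0.0.1:5000"  # assumed host/port; adjust to the real deployment

# Start a training task and read back its task_id
task_id = requests.post(f"{BASE_URL}/nn_bp_training").json()["task_id"]

# Poll the progress endpoint until the task finishes
while True:
    progress = requests.get(f"{BASE_URL}/training_progress/{task_id}").json()
    print(progress.get("progress"), progress.get("message"))
    if progress.get("status") in ("completed", "failed", "not_found"):
        break
    time.sleep(5)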

+ 18 - 18
models_processing/model_koi/tf_bp_train.py

@@ -47,24 +47,24 @@ def model_training_bp():
     result = {}
     success = 0
     print("Program starts execution!")
-    # try:
-    # ------------ Fetch data and preprocess the training data ------------
-    train_data = get_data_from_mongo(args)
-    train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes = dh.train_data_handler(train_data, bp_data=True)
-    # ------------ Train the model ------------
-    bp.opt.Model['input_size'] = train_x.shape[1]
-    bp_model = bp.training([train_x, train_y, valid_x, valid_y])
-    # ------------ Save the model ------------
-    args['params'] = json.dumps(args)
-    args['descr'] = 'test'
-    args['gen_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
-    insert_trained_model_into_mongo(bp_model, args)
-    insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, args)
-    success = 1
-    # except Exception as e:
-    #     my_exception = traceback.format_exc()
-    #     my_exception.replace("\n", "\t")
-    #     result['msg'] = my_exception
+    try:
+        # ------------ Fetch data and preprocess the training data ------------
+        train_data = get_data_from_mongo(args)
+        train_x, train_y, valid_x, valid_y, scaled_train_bytes, scaled_target_bytes = dh.train_data_handler(train_data, bp_data=True)
+        # ------------ Train the model ------------
+        bp.opt.Model['input_size'] = train_x.shape[1]
+        bp_model = bp.training([train_x, train_y, valid_x, valid_y])
+        # ------------ Save the model ------------
+        args['params'] = json.dumps(args)
+        args['descr'] = 'test'
+        args['gen_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
+        insert_trained_model_into_mongo(bp_model, args)
+        insert_scaler_model_into_mongo(scaled_train_bytes, scaled_target_bytes, args)
+        success = 1
+    except Exception as e:
+        my_exception = traceback.format_exc()
+        my_exception = my_exception.replace("\n", "\t")
+        result['msg'] = my_exception
     end_time = time.time()
 
     result['success'] = success
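One detail in the re-enabled except block: str.replace returns a new string rather than modifying the original, so the result must be assigned back (as done above) or the traceback keeps its newlines. A small illustration:

import traceback

try:
    raise ValueError("boom")
except Exception:
    msg = traceback.format_exc()
    msg.replace("\n", "\t")        # no effect: the new string is discarded
    msg = msg.replace("\n", "\t")  # keeps the flattened, single-line message
    print(msg)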

+ 1 - 1
models_processing/model_koi/tf_cnn.py

@@ -12,7 +12,7 @@ from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoa
 from tensorflow.keras import optimizers, regularizers
 from models_processing.losses.loss_cdq import rmse
 import numpy as np
-from common.database_dml import *
+from common.database_dml_koi import *
 from threading import Lock
 import argparse
 model_lock = Lock()

+ 1 - 1
models_processing/model_koi/tf_cnn_train.py

@@ -12,7 +12,7 @@ import logging, argparse
 from data_processing.data_operation.data_handler import DataHandler
 import time, yaml
 from models_processing.model_koi.tf_cnn import CNNHandler
-from common.database_dml import *
+from common.database_dml_koi import *
 import matplotlib.pyplot as plt
 from common.logs import Log
 logger = logging.getLogger()

+ 1 - 1
models_processing/model_koi/tf_lstm.py

@@ -11,7 +11,7 @@ from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoa
 from tensorflow.keras import optimizers, regularizers
 from models_processing.losses.loss_cdq import rmse
 import numpy as np
-from common.database_dml import *
+from common.database_dml_koi import *
 from threading import Lock
 import argparse
 model_lock = Lock()

+ 1 - 1
models_processing/model_koi/tf_lstm_train.py

@@ -14,7 +14,7 @@ import time, yaml
 from models_processing.model_koi.tf_lstm import TSHandler
 from models_processing.model_koi.tf_cnn import CNNHandler
 
-from common.database_dml import *
+from common.database_dml_koi import *
 import matplotlib.pyplot as plt
 from common.logs import Log
 logger = logging.getLogger()