David 2 months ago
parent
commit
1fdb521fb9

+ 12 - 0
models_processing/model_koi/async_query_task.py

@@ -62,6 +62,18 @@ def get_progress(task_id):
     return jsonify(progress)
     return jsonify(progress)
 
 
 
 
+@app.route('/training_progress/')
+def get_progress(task_id):
+    """查询训练进度接口"""
+    with progress_lock:
+        progress = training_progress.get(task_id, {
+            'status': 'not_found',
+            'progress': 0,
+            'message': '任务不存在'
+        })
+
+    return jsonify(progress)
+
 def async_training_task(task_id):
 def async_training_task(task_id):
     """异步训练任务"""
     """异步训练任务"""
     args = {}  # 根据实际情况获取参数
     args = {}  # 根据实际情况获取参数

+ 1 - 1
models_processing/model_koi/tf_lstm_pre.py

@@ -52,7 +52,7 @@ def model_prediction_bp():
     try:
     try:
         pre_data = get_data_from_mongo(args)
         pre_data = get_data_from_mongo(args)
         feature_scaler, target_scaler = get_scaler_model_from_mongo(args)
         feature_scaler, target_scaler = get_scaler_model_from_mongo(args)
-        scaled_pre_x = dh.pre_data_handler(pre_data, feature_scaler, args)
+        scaled_pre_x = dh.pre_data_handler(pre_data, feature_scaler)
         ts.get_model(args)
         ts.get_model(args)
         # result = bp.predict(scaled_pre_x, args)
         # result = bp.predict(scaled_pre_x, args)
         res = list(chain.from_iterable(target_scaler.inverse_transform([ts.predict(scaled_pre_x).flatten()])))
         res = list(chain.from_iterable(target_scaler.inverse_transform([ts.predict(scaled_pre_x).flatten()])))

+ 3 - 8
models_processing/model_koi/tf_lstm_train.py

@@ -6,16 +6,13 @@
 # @Company: shenyang JY
 # @Company: shenyang JY
 import json, copy
 import json, copy
 import numpy as np
 import numpy as np
-from flask import Flask, request
-import traceback
+from flask import Flask, request, jsonify
+import traceback, uuid
 import logging, argparse
 import logging, argparse
 from data_processing.data_operation.data_handler import DataHandler
 from data_processing.data_operation.data_handler import DataHandler
-import time, yaml
+import time, yaml, threading
 from models_processing.model_koi.tf_lstm import TSHandler
 from models_processing.model_koi.tf_lstm import TSHandler
-from models_processing.model_koi.tf_cnn import CNNHandler
-
 from common.database_dml_koi import *
 from common.database_dml_koi import *
-import matplotlib.pyplot as plt
 from common.logs import Log
 from common.logs import Log
 logger = logging.getLogger()
 logger = logging.getLogger()
 # logger = Log('models-processing').logger
 # logger = Log('models-processing').logger
@@ -29,7 +26,6 @@ with app.app_context():
 
 
     dh = DataHandler(logger, args)
     dh = DataHandler(logger, args)
     ts = TSHandler(logger, args)
     ts = TSHandler(logger, args)
-    # ts = CNNHandler(logger, args)
     global opt
     global opt
 
 
 @app.before_request
 @app.before_request
@@ -48,7 +44,6 @@ def model_training_bp():
     # 获取程序开始时间
     # 获取程序开始时间
     start_time = time.time()
     start_time = time.time()
     result = {}
     result = {}
-    success = 0
     print("Program starts execution!")
     print("Program starts execution!")
     try:
     try:
         # ------------ 获取数据,预处理训练数据 ------------
         # ------------ 获取数据,预处理训练数据 ------------