|
@@ -1,4 +1,6 @@
|
|
-from pymongo import MongoClient, UpdateOne
|
|
|
|
|
|
+import pymongo
|
|
|
|
+from pymongo import MongoClient, UpdateOne, DESCENDING, ASCENDING
|
|
|
|
+from pymongo.errors import PyMongoError
|
|
import pandas as pd
|
|
import pandas as pd
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy import create_engine
|
|
import pickle
|
|
import pickle
|
|
@@ -6,6 +8,8 @@ from io import BytesIO
|
|
import joblib
|
|
import joblib
|
|
import h5py
|
|
import h5py
|
|
import tensorflow as tf
|
|
import tensorflow as tf
|
|
|
|
+import os
|
|
|
|
+import tempfile
|
|
|
|
|
|
def get_data_from_mongo(args):
|
|
def get_data_from_mongo(args):
|
|
mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
|
|
mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
|
|
@@ -103,10 +107,12 @@ def insert_data_into_mongo(res_df, args):
|
|
# 批量执行更新/插入操作
|
|
# 批量执行更新/插入操作
|
|
if operations:
|
|
if operations:
|
|
result = collection.bulk_write(operations)
|
|
result = collection.bulk_write(operations)
|
|
|
|
+ client.close()
|
|
print(f"Matched: {result.matched_count}, Upserts: {result.upserted_count}")
|
|
print(f"Matched: {result.matched_count}, Upserts: {result.upserted_count}")
|
|
else:
|
|
else:
|
|
# 追加模式:直接插入新数据
|
|
# 追加模式:直接插入新数据
|
|
collection.insert_many(data_dict)
|
|
collection.insert_many(data_dict)
|
|
|
|
+ client.close()
|
|
print("Data inserted successfully!")
|
|
print("Data inserted successfully!")
|
|
|
|
|
|
|
|
|
|
@@ -139,6 +145,7 @@ def insert_pickle_model_into_mongo(model, args):
|
|
print(f"Collection '{mongodb_write_table} already exist, deleted successfully!")
|
|
print(f"Collection '{mongodb_write_table} already exist, deleted successfully!")
|
|
collection = db[mongodb_write_table] # 集合名称
|
|
collection = db[mongodb_write_table] # 集合名称
|
|
collection.insert_one(model_data)
|
|
collection.insert_one(model_data)
|
|
|
|
+ client.close()
|
|
print("model inserted successfully!")
|
|
print("model inserted successfully!")
|
|
|
|
|
|
|
|
|
|
@@ -161,49 +168,76 @@ def insert_h5_model_into_mongo(model,feature_scaler_bytes,target_scaler_bytes ,a
|
|
db[model_table].drop()
|
|
db[model_table].drop()
|
|
print(f"Collection '{model_table} already exist, deleted successfully!")
|
|
print(f"Collection '{model_table} already exist, deleted successfully!")
|
|
model_table = db[model_table]
|
|
model_table = db[model_table]
|
|
- # 创建 BytesIO 缓冲区
|
|
|
|
- model_buffer = BytesIO()
|
|
|
|
- # 将模型保存为 HDF5 格式到内存 (BytesIO)
|
|
|
|
- model.save(model_buffer, save_format='h5')
|
|
|
|
- # 将指针移到缓冲区的起始位置
|
|
|
|
- model_buffer.seek(0)
|
|
|
|
- # 获取模型的二进制数据
|
|
|
|
- model_data = model_buffer.read()
|
|
|
|
- # 将模型保存到 MongoDB
|
|
|
|
- model_table.insert_one({
|
|
|
|
- "model_name": model_name,
|
|
|
|
- "model_data": model_data
|
|
|
|
- })
|
|
|
|
- print("模型成功保存到 MongoDB!")
|
|
|
|
|
|
+ fd, temp_path = None, None
|
|
|
|
+ client = None
|
|
|
|
|
|
-def insert_trained_model_into_mongo(model, args):
|
|
|
|
- mongodb_connection,mongodb_database,model_table,model_name = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
|
|
|
|
- args['mongodb_database'],args['model_table'],args['model_name'])
|
|
|
|
|
|
+ try:
|
|
|
|
+ # ------------------------- 临时文件处理 -------------------------
|
|
|
|
+ fd, temp_path = tempfile.mkstemp(suffix='.keras')
|
|
|
|
+ os.close(fd) # 立即释放文件锁
|
|
|
|
|
|
- gen_time, params_json, descr = args['gen_time'], args['params'], args['descr']
|
|
|
|
- client = MongoClient(mongodb_connection)
|
|
|
|
- db = client[mongodb_database]
|
|
|
|
- if model_table in db.list_collection_names():
|
|
|
|
- db[model_table].drop()
|
|
|
|
- print(f"Collection '{model_table} already exist, deleted successfully!")
|
|
|
|
- model_table = db[model_table]
|
|
|
|
- # 创建 BytesIO 缓冲区
|
|
|
|
- model_buffer = BytesIO()
|
|
|
|
- # 将模型保存为 HDF5 格式到内存 (BytesIO)
|
|
|
|
- model.save(model_buffer, save_format='h5')
|
|
|
|
- # 将指针移到缓冲区的起始位置
|
|
|
|
- model_buffer.seek(0)
|
|
|
|
- # 获取模型的二进制数据
|
|
|
|
- model_data = model_buffer.read()
|
|
|
|
- # 将模型保存到 MongoDB
|
|
|
|
- model_table.insert_one({
|
|
|
|
- "model_name": model_name,
|
|
|
|
- "model_data": model_data,
|
|
|
|
- "gen_time": gen_time,
|
|
|
|
- "params": params_json,
|
|
|
|
- "descr": descr
|
|
|
|
- })
|
|
|
|
- print("模型成功保存到 MongoDB!")
|
|
|
|
|
|
+ # ------------------------- 模型保存 -------------------------
|
|
|
|
+ try:
|
|
|
|
+ model.save(temp_path) # 不指定save_format,默认使用keras新格式
|
|
|
|
+ except Exception as e:
|
|
|
|
+ raise RuntimeError(f"模型保存失败: {str(e)}") from e
|
|
|
|
+
|
|
|
|
+ # ------------------------- 数据插入 -------------------------
|
|
|
|
+ with open(temp_path, 'rb') as f:
|
|
|
|
+ result = model_table.insert_one({
|
|
|
|
+ "model_name": args['model_name'],
|
|
|
|
+ "model_data": f.read(),
|
|
|
|
+ })
|
|
|
|
+
|
|
|
|
+ print(f"✅ 模型 {args['model_name']} 保存成功 | 文档ID: {result.inserted_id}")
|
|
|
|
+ return str(result.inserted_id)
|
|
|
|
+
|
|
|
|
+ except Exception as e:
|
|
|
|
+ # ------------------------- 异常分类处理 -------------------------
|
|
|
|
+ error_type = "数据库操作" if isinstance(e, (pymongo.errors.PyMongoError, RuntimeError)) else "系统错误"
|
|
|
|
+ print(f"❌ {error_type} - 详细错误: {str(e)}")
|
|
|
|
+ raise # 根据业务需求决定是否重新抛出
|
|
|
|
+
|
|
|
|
+ finally:
|
|
|
|
+ # ------------------------- 资源清理 -------------------------
|
|
|
|
+ if client:
|
|
|
|
+ client.close()
|
|
|
|
+ if temp_path and os.path.exists(temp_path):
|
|
|
|
+ try:
|
|
|
|
+ os.remove(temp_path)
|
|
|
|
+ except PermissionError:
|
|
|
|
+ print(f"⚠️ 临时文件清理失败: {temp_path}")
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+# def insert_trained_model_into_mongo(model, args):
|
|
|
|
+# mongodb_connection,mongodb_database,model_table,model_name = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
|
|
|
|
+# args['mongodb_database'],args['model_table'],args['model_name'])
|
|
|
|
+#
|
|
|
|
+# gen_time, params_json, descr = args['gen_time'], args['params'], args['descr']
|
|
|
|
+# client = MongoClient(mongodb_connection)
|
|
|
|
+# db = client[mongodb_database]
|
|
|
|
+# if model_table in db.list_collection_names():
|
|
|
|
+# db[model_table].drop()
|
|
|
|
+# print(f"Collection '{model_table} already exist, deleted successfully!")
|
|
|
|
+# model_table = db[model_table]
|
|
|
|
+#
|
|
|
|
+# # 创建 BytesIO 缓冲区
|
|
|
|
+# model_buffer = BytesIO()
|
|
|
|
+# # 将模型保存为 HDF5 格式到内存 (BytesIO)
|
|
|
|
+# model.save(model_buffer, save_format='h5')
|
|
|
|
+# # 将指针移到缓冲区的起始位置
|
|
|
|
+# model_buffer.seek(0)
|
|
|
|
+# # 获取模型的二进制数据
|
|
|
|
+# model_data = model_buffer.read()
|
|
|
|
+# # 将模型保存到 MongoDB
|
|
|
|
+# model_table.insert_one({
|
|
|
|
+# "model_name": model_name,
|
|
|
|
+# "model_data": model_data,
|
|
|
|
+# "gen_time": gen_time,
|
|
|
|
+# "params": params_json,
|
|
|
|
+# "descr": descr
|
|
|
|
+# })
|
|
|
|
+# print("模型成功保存到 MongoDB!")
|
|
|
|
|
|
def insert_scaler_model_into_mongo(feature_scaler_bytes, scaled_target_bytes, args):
|
|
def insert_scaler_model_into_mongo(feature_scaler_bytes, scaled_target_bytes, args):
|
|
mongodb_connection,mongodb_database,scaler_table,model_table,model_name = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
|
|
mongodb_connection,mongodb_database,scaler_table,model_table,model_name = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
|
|
@@ -219,6 +253,7 @@ def insert_scaler_model_into_mongo(feature_scaler_bytes, scaled_target_bytes, ar
|
|
"feature_scaler": feature_scaler_bytes.read(),
|
|
"feature_scaler": feature_scaler_bytes.read(),
|
|
"target_scaler": scaled_target_bytes.read()
|
|
"target_scaler": scaled_target_bytes.read()
|
|
})
|
|
})
|
|
|
|
+ client.close()
|
|
print("scaler_model inserted successfully!")
|
|
print("scaler_model inserted successfully!")
|
|
|
|
|
|
|
|
|
|
@@ -233,14 +268,26 @@ def get_h5_model_from_mongo(args, custom=None):
|
|
model_doc = collection.find_one({"model_name": model_name})
|
|
model_doc = collection.find_one({"model_name": model_name})
|
|
if model_doc:
|
|
if model_doc:
|
|
model_data = model_doc['model_data'] # 获取模型的二进制数据
|
|
model_data = model_doc['model_data'] # 获取模型的二进制数据
|
|
- # 将二进制数据加载到 BytesIO 缓冲区
|
|
|
|
- model_buffer = BytesIO(model_data)
|
|
|
|
- # 从缓冲区加载模型
|
|
|
|
- # 使用 h5py 和 BytesIO 从内存中加载模型
|
|
|
|
- with h5py.File(model_buffer, 'r') as f:
|
|
|
|
- model = tf.keras.models.load_model(f, custom_objects=custom)
|
|
|
|
- print(f"{model_name}模型成功从 MongoDB 加载!")
|
|
|
|
|
|
+ # 创建临时文件(自动删除)
|
|
|
|
+ with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
|
|
|
|
+ tmp_file.write(model_data)
|
|
|
|
+ tmp_file_path = tmp_file.name # 记录文件路径
|
|
|
|
+
|
|
|
|
+ # 从临时文件加载模型
|
|
|
|
+ model = tf.keras.models.load_model(
|
|
|
|
+ tmp_file_path,
|
|
|
|
+ custom_objects=custom
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ print(f"{args['model_name']} 模型成功从 MongoDB 加载!")
|
|
client.close()
|
|
client.close()
|
|
|
|
+ # 确保删除临时文件
|
|
|
|
+ if tmp_file_path and os.path.exists(tmp_file_path):
|
|
|
|
+ try:
|
|
|
|
+ os.remove(tmp_file_path)
|
|
|
|
+ print(f"🧹 已清理临时文件: {tmp_file_path}")
|
|
|
|
+ except Exception as cleanup_err:
|
|
|
|
+ print(f"⚠️ 临时文件清理失败: {str(cleanup_err)}")
|
|
return model
|
|
return model
|
|
else:
|
|
else:
|
|
print(f"未找到model_name为 {model_name} 的模型。")
|
|
print(f"未找到model_name为 {model_name} 的模型。")
|
|
@@ -264,4 +311,5 @@ def get_scaler_model_from_mongo(args, only_feature_scaler=False):
|
|
return feature_scaler
|
|
return feature_scaler
|
|
target_scaler_bytes = BytesIO(scaler_doc["target_scaler"])
|
|
target_scaler_bytes = BytesIO(scaler_doc["target_scaler"])
|
|
target_scaler = joblib.load(target_scaler_bytes)
|
|
target_scaler = joblib.load(target_scaler_bytes)
|
|
|
|
+ client.close()
|
|
return feature_scaler,target_scaler
|
|
return feature_scaler,target_scaler
|