|
@@ -174,12 +174,12 @@ def insert_trained_model_into_mongo(model: tf.keras.Model, args: Dict[str, Any])
|
|
|
|
|
|
try:
|
|
try:
|
|
# ------------------------- 临时文件处理 -------------------------
|
|
# ------------------------- 临时文件处理 -------------------------
|
|
- fd, temp_path = tempfile.mkstemp(suffix='.h5')
|
|
|
|
|
|
+ fd, temp_path = tempfile.mkstemp(suffix='.keras')
|
|
os.close(fd) # 立即释放文件锁
|
|
os.close(fd) # 立即释放文件锁
|
|
|
|
|
|
# ------------------------- 模型保存 -------------------------
|
|
# ------------------------- 模型保存 -------------------------
|
|
try:
|
|
try:
|
|
- model.save(temp_path, save_format='h5')
|
|
|
|
|
|
+ model.save(temp_path) # 不指定save_format,默认使用keras新格式
|
|
except Exception as e:
|
|
except Exception as e:
|
|
raise RuntimeError(f"模型保存失败: {str(e)}") from e
|
|
raise RuntimeError(f"模型保存失败: {str(e)}") from e
|
|
|
|
|
|
@@ -337,6 +337,7 @@ def get_h5_model_from_mongo( args: Dict[str, Any], custom_objects: Optional[Dict
|
|
# ------------------------- 环境配置 -------------------------
|
|
# ------------------------- 环境配置 -------------------------
|
|
mongo_uri = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
|
|
mongo_uri = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
|
|
client = None
|
|
client = None
|
|
|
|
+ tmp_file_path = None # 用于跟踪临时文件路径
|
|
try:
|
|
try:
|
|
# ------------------------- 数据库连接 -------------------------
|
|
# ------------------------- 数据库连接 -------------------------
|
|
client = MongoClient(
|
|
client = MongoClient(
|
|
@@ -379,7 +380,7 @@ def get_h5_model_from_mongo( args: Dict[str, Any], custom_objects: Optional[Dict
|
|
# with h5py.File(model_buffer, 'r', driver='fileobj') as f:
|
|
# with h5py.File(model_buffer, 'r', driver='fileobj') as f:
|
|
# model = tf.keras.models.load_model(f, custom_objects=custom_objects)
|
|
# model = tf.keras.models.load_model(f, custom_objects=custom_objects)
|
|
# 创建临时文件
|
|
# 创建临时文件
|
|
- with tempfile.NamedTemporaryFile(suffix=".h5", delete=False) as tmp_file:
|
|
|
|
|
|
+ with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
|
|
tmp_file.write(model_data)
|
|
tmp_file.write(model_data)
|
|
tmp_file_path = tmp_file.name # 获取临时文件路径
|
|
tmp_file_path = tmp_file.name # 获取临时文件路径
|
|
|
|
|
|
@@ -400,7 +401,109 @@ def get_h5_model_from_mongo( args: Dict[str, Any], custom_objects: Optional[Dict
|
|
# ------------------------- 资源清理 -------------------------
|
|
# ------------------------- 资源清理 -------------------------
|
|
if client:
|
|
if client:
|
|
client.close()
|
|
client.close()
|
|
|
|
+ # 确保删除临时文件
|
|
|
|
+ if tmp_file_path and os.path.exists(tmp_file_path):
|
|
|
|
+ try:
|
|
|
|
+ os.remove(tmp_file_path)
|
|
|
|
+ print(f"🧹 已清理临时文件: {tmp_file_path}")
|
|
|
|
+ except Exception as cleanup_err:
|
|
|
|
+ print(f"⚠️ 临时文件清理失败: {str(cleanup_err)}")
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def get_keras_model_from_mongo(
|
|
|
|
+ args: Dict[str, Any],
|
|
|
|
+ custom_objects: Optional[Dict[str, Any]] = None
|
|
|
|
+) -> Optional[tf.keras.Model]:
|
|
|
|
+ """
|
|
|
|
+ 从MongoDB获取指定模型的最新版本(支持Keras格式)
|
|
|
|
+
|
|
|
|
+ 参数:
|
|
|
|
+ args : dict - 包含以下键的字典:
|
|
|
|
+ mongodb_database: 数据库名称
|
|
|
|
+ model_table: 集合名称
|
|
|
|
+ model_name: 要获取的模型名称
|
|
|
|
+ custom_objects: dict - 自定义Keras对象字典
|
|
|
|
+
|
|
|
|
+ 返回:
|
|
|
|
+ tf.keras.Model - 加载成功的Keras模型
|
|
|
|
+ """
|
|
|
|
+ # ------------------------- 参数校验 -------------------------
|
|
|
|
+ required_keys = {'mongodb_database', 'model_table', 'model_name'}
|
|
|
|
+ if missing := required_keys - args.keys():
|
|
|
|
+ raise ValueError(f"❌ 缺失必要参数: {missing}")
|
|
|
|
+
|
|
|
|
+ # ------------------------- 环境配置 -------------------------
|
|
|
|
+ mongo_uri = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
|
|
|
|
+ client = None
|
|
|
|
+ tmp_file_path = None # 用于跟踪临时文件路径
|
|
|
|
+
|
|
|
|
+ try:
|
|
|
|
+ # ------------------------- 数据库连接 -------------------------
|
|
|
|
+ client = MongoClient(
|
|
|
|
+ mongo_uri,
|
|
|
|
+ maxPoolSize=10,
|
|
|
|
+ socketTimeoutMS=5000
|
|
|
|
+ )
|
|
|
|
+ db = client[args['mongodb_database']]
|
|
|
|
+ collection = db[args['model_table']]
|
|
|
|
|
|
|
|
+ # ------------------------- 索引维护 -------------------------
|
|
|
|
+ index_name = "model_gen_time_idx"
|
|
|
|
+ if index_name not in collection.index_information():
|
|
|
|
+ collection.create_index(
|
|
|
|
+ [("model_name", 1), ("gen_time", DESCENDING)],
|
|
|
|
+ name=index_name
|
|
|
|
+ )
|
|
|
|
+ print("⏱️ 已创建复合索引")
|
|
|
|
+
|
|
|
|
+ # ------------------------- 高效查询 -------------------------
|
|
|
|
+ model_doc = collection.find_one(
|
|
|
|
+ {"model_name": args['model_name']},
|
|
|
|
+ sort=[('gen_time', DESCENDING)],
|
|
|
|
+ projection={"model_data": 1, "gen_time": 1}
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ if not model_doc:
|
|
|
|
+ print(f"⚠️ 未找到模型 '{args['model_name']}' 的有效记录")
|
|
|
|
+ return None
|
|
|
|
+
|
|
|
|
+ # ------------------------- 内存优化加载 -------------------------
|
|
|
|
+ model_data = model_doc['model_data']
|
|
|
|
+
|
|
|
|
+ # 创建临时文件(自动删除)
|
|
|
|
+ with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
|
|
|
|
+ tmp_file.write(model_data)
|
|
|
|
+ tmp_file_path = tmp_file.name # 记录文件路径
|
|
|
|
+
|
|
|
|
+ # 从临时文件加载模型
|
|
|
|
+ model = tf.keras.models.load_model(
|
|
|
|
+ tmp_file_path,
|
|
|
|
+ custom_objects=custom_objects
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ print(f"{args['model_name']} 模型成功从 MongoDB 加载!")
|
|
|
|
+ return model
|
|
|
|
+
|
|
|
|
+ except tf.errors.NotFoundError as e:
|
|
|
|
+ print(f"❌ 模型结构缺失关键组件: {str(e)}")
|
|
|
|
+ raise RuntimeError("模型架构不完整") from e
|
|
|
|
+
|
|
|
|
+ except Exception as e:
|
|
|
|
+ print(f"❌ 系统异常: {str(e)}")
|
|
|
|
+ raise
|
|
|
|
+
|
|
|
|
+ finally:
|
|
|
|
+ # ------------------------- 资源清理 -------------------------
|
|
|
|
+ if client:
|
|
|
|
+ client.close()
|
|
|
|
+
|
|
|
|
+ # 确保删除临时文件
|
|
|
|
+ if tmp_file_path and os.path.exists(tmp_file_path):
|
|
|
|
+ try:
|
|
|
|
+ os.remove(tmp_file_path)
|
|
|
|
+ print(f"🧹 已清理临时文件: {tmp_file_path}")
|
|
|
|
+ except Exception as cleanup_err:
|
|
|
|
+ print(f"⚠️ 临时文件清理失败: {str(cleanup_err)}")
|
|
|
|
|
|
def get_scaler_model_from_mongo(args: Dict[str, Any], only_feature_scaler: bool = False) -> Union[object, Tuple[object, object]]:
|
|
def get_scaler_model_from_mongo(args: Dict[str, Any], only_feature_scaler: bool = False) -> Union[object, Tuple[object, object]]:
|
|
"""
|
|
"""
|