dbmg.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483
  1. import pymongo
  2. from pymongo import MongoClient, UpdateOne, DESCENDING, ASCENDING
  3. from pymongo.errors import PyMongoError
  4. import pandas as pd
  5. from sqlalchemy import create_engine
  6. import pickle
  7. from io import BytesIO
  8. import joblib
  9. import h5py, os, io
  10. import tensorflow as tf
  11. from typing import Dict, Any, Optional, Union, Tuple
  12. import tempfile
  13. class MongoUtils(object):
  14. def __init__(self, logger):
  15. self.logger = logger
  16. def insert_trained_model_into_mongo(self, model: tf.keras.Model, args: Dict[str, Any]) -> str:
  17. """
  18. 将训练好的H5模型插入MongoDB,自动维护集合容量不超过50个模型
  19. 参数:
  20. model : keras模型 - 训练好的Keras模型
  21. args : dict - 包含以下键的字典:
  22. mongodb_database: 数据库名称
  23. model_table: 集合名称
  24. model_name: 模型名称
  25. gen_time: 模型生成时间(datetime对象)
  26. params: 模型参数(JSON可序列化对象)
  27. descr: 模型描述文本
  28. """
  29. # ------------------------- 参数校验 -------------------------
  30. required_keys = {'mongodb_database', 'model_table', 'model_name',
  31. 'gen_time', 'params', 'descr'}
  32. if missing := required_keys - args.keys():
  33. raise ValueError(f"缺少必要参数: {missing}")
  34. # ------------------------- 配置解耦 -------------------------
  35. # 从环境变量获取连接信息(更安全)
  36. mongodb_connection = os.getenv("MONGO_URI", "mongodb://root:Syjy*3377@localhost:27017/")
  37. # ------------------------- 资源初始化 -------------------------
  38. fd, temp_path = None, None
  39. client = None
  40. try:
  41. # ------------------------- 临时文件处理 -------------------------
  42. fd, temp_path = tempfile.mkstemp(suffix='.keras')
  43. os.close(fd) # 立即释放文件锁
  44. # ------------------------- 模型保存 -------------------------
  45. try:
  46. model.save(temp_path) # 不指定save_format,默认使用keras新格式
  47. except Exception as e:
  48. raise RuntimeError(f"模型保存失败: {str(e)}") from e
  49. # ------------------------- 数据库连接 -------------------------
  50. client = MongoClient(mongodb_connection)
  51. db = client[args['mongodb_database']]
  52. collection = db[args['model_table']]
  53. # ------------------------- 索引检查 -------------------------
  54. # index_info = collection.index_information()
  55. # if "gen_time_1" not in index_info:
  56. # print("开始创建索引...")
  57. # collection.create_index(
  58. # [("gen_time", ASCENDING)],
  59. # name="gen_time_1",
  60. # background=True
  61. # )
  62. # print("索引创建成功")
  63. # else:
  64. # print("索引已存在,跳过创建")
  65. # ------------------------- 容量控制 -------------------------
  66. # 使用更高效的计数方式
  67. if collection.estimated_document_count() >= 50:
  68. # 原子性删除操作
  69. if deleted := collection.find_one_and_delete(
  70. sort=[("gen_time", ASCENDING)],
  71. projection={"_id": 1, "model_name": 1, "gen_time": 1}
  72. ):
  73. self.logger.info(f"已淘汰模型 [{deleted['model_name']}] 生成时间: {deleted['gen_time']}")
  74. # ------------------------- 数据插入 -------------------------
  75. with open(temp_path, 'rb') as f:
  76. result = collection.insert_one({
  77. "model_name": args['model_name'],
  78. "model_data": f.read(),
  79. "gen_time": args['gen_time'],
  80. "params": args['params'],
  81. "descr": args['descr']
  82. })
  83. self.logger.info(f"✅ 模型 {args['model_name']} 保存成功 | 文档ID: {result.inserted_id}")
  84. return str(result.inserted_id)
  85. except Exception as e:
  86. # ------------------------- 异常分类处理 -------------------------
  87. error_type = "数据库操作" if isinstance(e, (pymongo.errors.PyMongoError, RuntimeError)) else "系统错误"
  88. self.logger.info(f"❌ {error_type} - 详细错误: {str(e)}")
  89. raise # 根据业务需求决定是否重新抛出
  90. finally:
  91. # ------------------------- 资源清理 -------------------------
  92. if client:
  93. client.close()
  94. if temp_path and os.path.exists(temp_path):
  95. try:
  96. os.remove(temp_path)
  97. except PermissionError:
  98. self.logger.info(f"⚠️ 临时文件清理失败: {temp_path}")
  99. def insert_scaler_model_into_mongo(self, feature_scaler_bytes: BytesIO, target_scaler_bytes: BytesIO, args: Dict[str, Any]) -> str:
  100. """
  101. 将特征缩放器存储到MongoDB,自动维护集合容量不超过50个文档
  102. 参数:
  103. feature_scaler_bytes: BytesIO - 特征缩放器字节流
  104. scaled_target_bytes: BytesIO - 目标缩放器字节流
  105. args : dict - 包含以下键的字典:
  106. mongodb_database: 数据库名称
  107. scaler_table: 集合名称
  108. model_name: 关联模型名称
  109. gen_time: 生成时间(datetime对象)
  110. """
  111. # ------------------------- 参数校验 -------------------------
  112. required_keys = {'mongodb_database', 'scaler_table', 'model_name', 'gen_time'}
  113. if missing := required_keys - args.keys():
  114. raise ValueError(f"缺少必要参数: {missing}")
  115. # ------------------------- 配置解耦 -------------------------
  116. # 从环境变量获取连接信息(安全隔离凭证)
  117. mongodb_conn = os.getenv("MONGO_URI", "mongodb://root:Syjy*3377@localhost:27017/")
  118. # ------------------------- 输入验证 -------------------------
  119. for buf, name in [(feature_scaler_bytes, "特征缩放器"),
  120. (target_scaler_bytes, "目标缩放器")]:
  121. if not isinstance(buf, BytesIO):
  122. raise TypeError(f"{name} 必须为BytesIO类型")
  123. if buf.getbuffer().nbytes == 0:
  124. raise ValueError(f"{name} 字节流为空")
  125. client = None
  126. try:
  127. # ------------------------- 数据库连接 -------------------------
  128. client = MongoClient(mongodb_conn)
  129. db = client[args['mongodb_database']]
  130. collection = db[args['scaler_table']]
  131. # ------------------------- 索引维护 -------------------------
  132. # if "gen_time_1" not in collection.index_information():
  133. # collection.create_index([("gen_time", ASCENDING)], name="gen_time_1")
  134. # print("⏱️ 已创建时间排序索引")
  135. # ------------------------- 容量控制 -------------------------
  136. # 使用近似计数提升性能(误差在几十条内可接受)
  137. if collection.estimated_document_count() >= 50:
  138. # 原子性删除操作(保证事务完整性)
  139. if deleted := collection.find_one_and_delete(
  140. sort=[("gen_time", ASCENDING)],
  141. projection={"_id": 1, "model_name": 1, "gen_time": 1}
  142. ):
  143. self.logger.info(f"🗑️ 已淘汰最旧缩放器 [{deleted['model_name']}] 生成时间: {deleted['gen_time']}")
  144. # ------------------------- 数据插入 -------------------------
  145. # 确保字节流指针位置正确
  146. feature_scaler_bytes.seek(0)
  147. target_scaler_bytes.seek(0)
  148. result = collection.insert_one({
  149. "model_name": args['model_name'],
  150. "gen_time": args['gen_time'],
  151. "feature_scaler": feature_scaler_bytes.read(),
  152. "target_scaler": target_scaler_bytes.read()
  153. })
  154. self.logger.info(f"✅ 缩放器 {args['model_name']} 保存成功 | 文档ID: {result.inserted_id}")
  155. return str(result.inserted_id)
  156. except Exception as e:
  157. # ------------------------- 异常分类处理 -------------------------
  158. error_type = "数据库操作" if isinstance(e, (pymongo.errors.PyMongoError, ValueError)) else "系统错误"
  159. self.logger.info(f"❌ {error_type}异常 - 详细错误: {str(e)}")
  160. raise # 根据业务需求决定是否重新抛出
  161. finally:
  162. # ------------------------- 资源清理 -------------------------
  163. if client:
  164. client.close()
  165. # 重置字节流指针(确保后续可复用)
  166. feature_scaler_bytes.seek(0)
  167. target_scaler_bytes.seek(0)
  168. def get_h5_model_from_mongo(self, args: Dict[str, Any], custom_objects: Optional[Dict[str, Any]] = None) -> Optional[tf.keras.Model]:
  169. """
  170. 从MongoDB获取指定模型的最新版本
  171. 参数:
  172. args : dict - 包含以下键的字典:
  173. mongodb_database: 数据库名称
  174. model_table: 集合名称
  175. model_name: 要获取的模型名称
  176. custom_objects: dict - 自定义Keras对象字典
  177. 返回:
  178. tf.keras.Model - 加载成功的Keras模型
  179. """
  180. # ------------------------- 参数校验 -------------------------
  181. required_keys = {'mongodb_database', 'model_table', 'model_name'}
  182. if missing := required_keys - args.keys():
  183. raise ValueError(f"❌ 缺失必要参数: {missing}")
  184. # ------------------------- 环境配置 -------------------------
  185. mongo_uri = os.getenv("MONGO_URI", "mongodb://root:Syjy*3377@localhost:27017/")
  186. client = None
  187. tmp_file_path = None # 用于跟踪临时文件路径
  188. try:
  189. # ------------------------- 数据库连接 -------------------------
  190. client = MongoClient(
  191. mongo_uri,
  192. maxPoolSize=10, # 连接池优化
  193. socketTimeoutMS=5000
  194. )
  195. db = client[args['mongodb_database']]
  196. collection = db[args['model_table']]
  197. # ------------------------- 索引维护 -------------------------
  198. index_name = "model_gen_time_idx"
  199. if index_name not in collection.index_information():
  200. collection.create_index(
  201. [("model_name", 1), ("gen_time", DESCENDING)],
  202. name=index_name
  203. )
  204. self.logger.info("⏱️ 已创建复合索引")
  205. # ------------------------- 高效查询 -------------------------
  206. model_doc = collection.find_one(
  207. {"model_name": args['model_name']},
  208. sort=[('gen_time', DESCENDING)],
  209. projection={"model_data": 1, "gen_time": 1} # 获取必要字段
  210. )
  211. if not model_doc:
  212. self.logger.info(f"⚠️ 未找到模型 '{args['model_name']}' 的有效记录")
  213. return None
  214. # ------------------------- 内存优化加载 -------------------------
  215. if model_doc:
  216. model_data = model_doc['model_data'] # 获取模型的二进制数据
  217. # # 将二进制数据加载到 BytesIO 缓冲区
  218. # model_buffer = BytesIO(model_data)
  219. # # 确保指针在起始位置
  220. # model_buffer.seek(0)
  221. # # 从缓冲区加载模型
  222. # # 使用 h5py 和 BytesIO 从内存中加载模型
  223. # with h5py.File(model_buffer, 'r', driver='fileobj') as f:
  224. # model = tf.keras.models.load_model(f, custom_objects=custom_objects)
  225. # 创建临时文件
  226. with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
  227. tmp_file.write(model_data)
  228. tmp_file_path = tmp_file.name # 获取临时文件路径
  229. # 从临时文件加载模型
  230. model = tf.keras.models.load_model(tmp_file_path, custom_objects=custom_objects)
  231. self.logger.info(f"{args['model_name']}模型成功从 MongoDB 加载!")
  232. return model
  233. except tf.errors.NotFoundError as e:
  234. self.logger.info(f"❌ 模型结构缺失关键组件: {str(e)}")
  235. raise RuntimeError("模型架构不完整") from e
  236. except Exception as e:
  237. self.logger.info(f"❌ 系统异常: {str(e)}")
  238. raise
  239. finally:
  240. # ------------------------- 资源清理 -------------------------
  241. if client:
  242. client.close()
  243. # 确保删除临时文件
  244. if tmp_file_path and os.path.exists(tmp_file_path):
  245. try:
  246. os.remove(tmp_file_path)
  247. self.logger.info(f"🧹 已清理临时文件: {tmp_file_path}")
  248. except Exception as cleanup_err:
  249. self.logger.info(f"⚠️ 临时文件清理失败: {str(cleanup_err)}")
  250. def get_keras_model_from_mongo(self,
  251. args: Dict[str, Any],
  252. custom_objects: Optional[Dict[str, Any]] = None
  253. ) -> Optional[tf.keras.Model]:
  254. """
  255. 从MongoDB获取指定模型的最新版本(支持Keras格式)
  256. 参数:
  257. args : dict - 包含以下键的字典:
  258. mongodb_database: 数据库名称
  259. model_table: 集合名称
  260. model_name: 要获取的模型名称
  261. custom_objects: dict - 自定义Keras对象字典
  262. 返回:
  263. tf.keras.Model - 加载成功的Keras模型
  264. """
  265. # ------------------------- 参数校验 -------------------------
  266. required_keys = {'mongodb_database', 'model_table', 'model_name'}
  267. if missing := required_keys - args.keys():
  268. raise ValueError(f"❌ 缺失必要参数: {missing}")
  269. # ------------------------- 环境配置 -------------------------
  270. mongo_uri = os.getenv("MONGO_URI", "mongodb://root:Syjy*3377@localhost:27017/")
  271. client = None
  272. tmp_file_path = None # 用于跟踪临时文件路径
  273. try:
  274. # ------------------------- 数据库连接 -------------------------
  275. client = MongoClient(
  276. mongo_uri,
  277. maxPoolSize=10,
  278. socketTimeoutMS=5000
  279. )
  280. db = client[args['mongodb_database']]
  281. collection = db[args['model_table']]
  282. # ------------------------- 索引维护 -------------------------
  283. # index_name = "model_gen_time_idx"
  284. # if index_name not in collection.index_information():
  285. # collection.create_index(
  286. # [("model_name", 1), ("gen_time", DESCENDING)],
  287. # name=index_name
  288. # )
  289. # print("⏱️ 已创建复合索引")
  290. # ------------------------- 高效查询 -------------------------
  291. model_doc = collection.find_one(
  292. {"model_name": args['model_name']},
  293. sort=[('gen_time', DESCENDING)],
  294. projection={"model_data": 1, "gen_time": 1, 'params':1}
  295. )
  296. if not model_doc:
  297. self.logger.info(f"⚠️ 未找到模型 '{args['model_name']}' 的有效记录")
  298. return None
  299. # ------------------------- 内存优化加载 -------------------------
  300. model_data = model_doc['model_data']
  301. model_params = model_doc['params']
  302. # 创建临时文件(自动删除)
  303. with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
  304. tmp_file.write(model_data)
  305. tmp_file_path = tmp_file.name # 记录文件路径
  306. # 从临时文件加载模型
  307. model = tf.keras.models.load_model(
  308. tmp_file_path,
  309. custom_objects=custom_objects
  310. )
  311. self.logger.info(f"{args['model_name']} 模型成功从 MongoDB 加载!")
  312. return model, model_params
  313. except tf.errors.NotFoundError as e:
  314. self.logger.info(f"❌ 模型结构缺失关键组件: {str(e)}")
  315. raise RuntimeError("模型架构不完整") from e
  316. except Exception as e:
  317. self.logger.info(f"❌ 系统异常: {str(e)}")
  318. raise
  319. finally:
  320. # ------------------------- 资源清理 -------------------------
  321. if client:
  322. client.close()
  323. # 确保删除临时文件
  324. if tmp_file_path and os.path.exists(tmp_file_path):
  325. try:
  326. os.remove(tmp_file_path)
  327. self.logger.info(f"🧹 已清理临时文件: {tmp_file_path}")
  328. except Exception as cleanup_err:
  329. self.logger.info(f"⚠️ 临时文件清理失败: {str(cleanup_err)}")
  330. def get_scaler_model_from_mongo(self, args: Dict[str, Any], only_feature_scaler: bool = False) -> Union[object, Tuple[object, object]]:
  331. """
  332. 优化版特征缩放器加载函数 - 安全高效获取最新预处理模型
  333. 参数:
  334. args : 必须包含键:
  335. - mongodb_database: 数据库名称
  336. - scaler_table: 集合名称
  337. - model_name: 目标模型名称
  338. only_feature_scaler : 是否仅返回特征缩放器
  339. 返回:
  340. 单个缩放器对象或(feature_scaler, target_scaler)元组
  341. 异常:
  342. ValueError : 参数缺失或类型错误
  343. RuntimeError : 数据操作异常
  344. """
  345. # ------------------------- 参数校验 -------------------------
  346. required_keys = {'mongodb_database', 'scaler_table', 'model_name'}
  347. if missing := required_keys - args.keys():
  348. raise ValueError(f"❌ 缺失必要参数: {missing}")
  349. # ------------------------- 环境配置 -------------------------
  350. mongo_uri = os.getenv("MONGO_URI", "mongodb://root:Syjy*3377@localhost:27017/")
  351. client = None
  352. try:
  353. # ------------------------- 数据库连接 -------------------------
  354. client = MongoClient(
  355. mongo_uri,
  356. maxPoolSize=20, # 连接池上限
  357. socketTimeoutMS=3000, # 3秒超时
  358. serverSelectionTimeoutMS=5000 # 5秒服务器选择超时
  359. )
  360. db = client[args['mongodb_database']]
  361. collection = db[args['scaler_table']]
  362. # ------------------------- 索引维护 -------------------------
  363. # index_name = "model_gen_time_idx"
  364. # if index_name not in collection.index_information():
  365. # collection.create_index(
  366. # [("model_name", 1), ("gen_time", DESCENDING)],
  367. # name=index_name,
  368. # background=True # 后台构建避免阻塞
  369. # )
  370. # print("⏱️ 已创建特征缩放器复合索引")
  371. # ------------------------- 高效查询 -------------------------
  372. scaler_doc = collection.find_one(
  373. {"model_name": args['model_name']},
  374. sort=[('gen_time', DESCENDING)],
  375. projection={"feature_scaler": 1, "target_scaler": 1, "gen_time": 1}
  376. )
  377. if not scaler_doc:
  378. raise RuntimeError(f"⚠️ 找不到模型 {args['model_name']} 的缩放器记录")
  379. # ------------------------- 反序列化处理 -------------------------
  380. def load_scaler(data: bytes) -> object:
  381. """安全加载序列化对象"""
  382. with BytesIO(data) as buffer:
  383. buffer.seek(0) # 确保指针复位
  384. try:
  385. return joblib.load(buffer)
  386. except joblib.UnpicklingError as e:
  387. raise RuntimeError("反序列化失败 (可能版本不兼容)") from e
  388. # 特征缩放器加载
  389. feature_data = scaler_doc["feature_scaler"]
  390. if not isinstance(feature_data, bytes):
  391. raise RuntimeError("特征缩放器数据格式异常")
  392. feature_scaler = load_scaler(feature_data)
  393. if only_feature_scaler:
  394. return feature_scaler
  395. # 目标缩放器加载
  396. target_data = scaler_doc["target_scaler"]
  397. if not isinstance(target_data, bytes):
  398. raise RuntimeError("目标缩放器数据格式异常")
  399. target_scaler = load_scaler(target_data)
  400. self.logger.info(f"✅ 成功加载 {args['model_name']} 的缩放器 (版本时间: {scaler_doc.get('gen_time', '未知')})")
  401. return feature_scaler, target_scaler
  402. except PyMongoError as e:
  403. raise RuntimeError(f"🔌 数据库操作失败: {str(e)}") from e
  404. except RuntimeError as e:
  405. raise RuntimeError(f"🔌 mongo操作失败: {str(e)}") from e# 直接传递已封装的异常
  406. except Exception as e:
  407. raise RuntimeError(f"❌ 未知系统异常: {str(e)}") from e
  408. finally:
  409. # ------------------------- 资源清理 -------------------------
  410. if client:
  411. client.close()