# database_dml_koi.py

import pymongo
from pymongo import MongoClient, UpdateOne, DESCENDING, ASCENDING
from pymongo.errors import PyMongoError
import pandas as pd
from sqlalchemy import create_engine
import pickle
from io import BytesIO
import joblib
import h5py, os, io
import tensorflow as tf
from typing import Dict, Any, Optional, Union, Tuple
import tempfile

def get_data_from_mongo(args):
    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
    mongodb_database = args['mongodb_database']
    mongodb_read_table = args['mongodb_read_table']
    query_dict = {}
    if 'timeBegin' in args.keys():
        timeBegin = args['timeBegin']
        query_dict.update({"$gte": timeBegin})
    if 'timeEnd' in args.keys():
        timeEnd = args['timeEnd']
        query_dict.update({"$lte": timeEnd})
    client = MongoClient(mongodb_connection)
    # Select the database (MongoDB creates it automatically if it does not exist)
    db = client[mongodb_database]
    collection = db[mongodb_read_table]  # collection name
    if len(query_dict) != 0:
        query = {"dateTime": query_dict}
        cursor = collection.find(query)
    else:
        cursor = collection.find()
    data = list(cursor)
    df = pd.DataFrame(data)
    # Drop the _id field (optional)
    if '_id' in df.columns:
        df = df.drop(columns=['_id'])
    client.close()
    return df
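
# Illustrative usage sketch for get_data_from_mongo (not part of the original module).
# The database/collection names and time strings below are hypothetical; the function
# only requires 'mongodb_database' and 'mongodb_read_table', with optional
# 'timeBegin'/'timeEnd' bounds applied to the 'dateTime' field.
def _example_get_data_from_mongo():
    demo_args = {
        'mongodb_database': 'demo_db',            # hypothetical database name
        'mongodb_read_table': 'demo_collection',  # hypothetical collection name
        'timeBegin': '2024-01-01 00:00:00',       # optional lower bound on dateTime
        'timeEnd': '2024-01-02 00:00:00',         # optional upper bound on dateTime
    }
    df = get_data_from_mongo(demo_args)
    print(df.head())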

def get_df_list_from_mongo(args):
    mongodb_connection, mongodb_database, mongodb_read_table = \
        "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/", args['mongodb_database'], args['mongodb_read_table'].split(',')
    df_list = []
    client = MongoClient(mongodb_connection)
    # Select the database (MongoDB creates it automatically if it does not exist)
    db = client[mongodb_database]
    for table in mongodb_read_table:
        collection = db[table]  # collection name
        data_from_db = collection.find()  # returns a cursor
        # Convert the cursor to a list and build a pandas DataFrame
        df = pd.DataFrame(list(data_from_db))
        if '_id' in df.columns:
            df = df.drop(columns=['_id'])
        df_list.append(df)
    client.close()
    return df_list

def insert_data_into_mongo(res_df, args):
    """
    Insert a DataFrame into a MongoDB collection, either overwriting the collection,
    appending to it, or upserting on the given key columns.

    Parameters:
    - res_df: DataFrame to insert
    - args: dict with the MongoDB database and collection names; it may also contain:
        - overwrite: 1 to drop and recreate the collection, 0 to append
        - update_keys: comma-separated key columns (e.g. 'col1,col2') used for matching;
          matching documents are updated, the rest are inserted
    """
    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
    mongodb_database = args['mongodb_database']
    mongodb_write_table = args['mongodb_write_table']
    overwrite = 1
    update_keys = None
    if 'overwrite' in args.keys():
        overwrite = int(args['overwrite'])
    if 'update_keys' in args.keys():
        update_keys = args['update_keys'].split(',')
    client = MongoClient(mongodb_connection)
    db = client[mongodb_database]
    collection = db[mongodb_write_table]
    # Overwrite mode: drop the existing collection
    if overwrite:
        if mongodb_write_table in db.list_collection_names():
            collection.drop()
            print(f"Collection '{mongodb_write_table}' already exists, deleted successfully!")
    # Convert the DataFrame to a list of dicts (one dict per row)
    data_dict = res_df.to_dict("records")
    # Nothing to insert
    if not data_dict:
        print("No data to insert.")
        return
    # If update_keys is given, perform an upsert (update or insert)
    if update_keys and not overwrite:
        operations = []
        for record in data_dict:
            # Build the query used to match the document to update
            query = {key: record[key] for key in update_keys}
            operations.append(UpdateOne(query, {'$set': record}, upsert=True))
        # Execute the update/insert operations in bulk
        if operations:
            result = collection.bulk_write(operations)
            print(f"Matched: {result.matched_count}, Upserts: {result.upserted_count}")
    else:
        # Append mode: insert the new data directly
        collection.insert_many(data_dict)
        print("Data inserted successfully!")
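
# Illustrative usage sketch for insert_data_into_mongo (not part of the original
# module). Shows the upsert path: overwrite disabled and 'update_keys' naming the
# columns used to match existing documents. All names and values are hypothetical.
def _example_upsert_into_mongo():
    demo_df = pd.DataFrame([
        {'farm_id': 'F001', 'dateTime': '2024-01-01 00:15:00', 'power': 1.2},
        {'farm_id': 'F001', 'dateTime': '2024-01-01 00:30:00', 'power': 1.5},
    ])
    demo_args = {
        'mongodb_database': 'demo_db',           # hypothetical database name
        'mongodb_write_table': 'demo_forecast',  # hypothetical collection name
        'overwrite': '0',                        # keep existing documents
        'update_keys': 'farm_id,dateTime',       # match on these columns, update or insert
    }
    insert_data_into_mongo(demo_df, demo_args)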

def get_data_fromMysql(params):
    mysql_conn = params['mysql_conn']
    query_sql = params['query_sql']
    # Read observed weather data from the MySQL database
    engine = create_engine(f"mysql+pymysql://{mysql_conn}")
    # Run the SQL query
    env_df = pd.read_sql_query(query_sql, engine)
    return env_df
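
# Illustrative usage sketch for get_data_fromMysql (not part of the original module).
# 'mysql_conn' is everything after the "mysql+pymysql://" scheme, i.e.
# user:password@host:port/database; the credentials and query below are hypothetical.
def _example_get_data_from_mysql():
    demo_params = {
        'mysql_conn': 'user:password@127.0.0.1:3306/demo_db',                    # hypothetical DSN fragment
        'query_sql': 'SELECT dateTime, temperature FROM weather LIMIT 100',      # hypothetical query
    }
    env_df = get_data_fromMysql(demo_params)
    print(env_df.shape)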

def insert_pickle_model_into_mongo(model, args):
    mongodb_connection, mongodb_database, mongodb_write_table, model_name = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/", \
        args['mongodb_database'], args['mongodb_write_table'], args['model_name']
    client = MongoClient(mongodb_connection)
    db = client[mongodb_database]
    # Serialize the model
    model_bytes = pickle.dumps(model)
    model_data = {
        'model_name': model_name,
        'model': model_bytes,  # store the pickled model bytes in the document
    }
    print('Training completed!')
    if mongodb_write_table in db.list_collection_names():
        db[mongodb_write_table].drop()
        print(f"Collection '{mongodb_write_table}' already exists, deleted successfully!")
    collection = db[mongodb_write_table]  # collection name
    collection.insert_one(model_data)
    print("model inserted successfully!")

def insert_trained_model_into_mongo(model: tf.keras.Model, args: Dict[str, Any]) -> str:
    """
    Insert a trained Keras model (H5 format) into MongoDB, keeping the collection
    capped at 50 models.

    Parameters:
        model : tf.keras.Model - trained Keras model
        args : dict with the following keys:
            mongodb_database: database name
            model_table: collection name
            model_name: model name
            gen_time: model generation time (datetime object)
            params: model parameters (JSON-serializable object)
            descr: model description text
    """
    # ------------------------- Argument validation -------------------------
    required_keys = {'mongodb_database', 'model_table', 'model_name',
                     'gen_time', 'params', 'descr'}
    if missing := required_keys - args.keys():
        raise ValueError(f"Missing required parameters: {missing}")

    # ------------------------- Configuration -------------------------
    # Read the connection string from an environment variable (safer than hard-coding)
    mongodb_connection = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")

    # ------------------------- Resource initialization -------------------------
    fd, temp_path = None, None
    client = None
    try:
        # ------------------------- Temporary file handling -------------------------
        fd, temp_path = tempfile.mkstemp(suffix='.h5')
        os.close(fd)  # release the file handle immediately

        # ------------------------- Model saving -------------------------
        try:
            model.save(temp_path, save_format='h5')
        except Exception as e:
            raise RuntimeError(f"Failed to save model: {str(e)}") from e

        # ------------------------- Database connection -------------------------
        client = MongoClient(mongodb_connection)
        db = client[args['mongodb_database']]
        collection = db[args['model_table']]

        # ------------------------- Index check -------------------------
        if "gen_time_1" not in collection.index_information():
            collection.create_index([("gen_time", ASCENDING)], name="gen_time_1")
            print("Created gen_time index")

        # ------------------------- Capacity control -------------------------
        # estimated_document_count() is cheaper than an exact count
        if collection.estimated_document_count() >= 50:
            # Atomically remove the oldest model
            if deleted := collection.find_one_and_delete(
                sort=[("gen_time", ASCENDING)],
                projection={"_id": 1, "model_name": 1, "gen_time": 1}
            ):
                print(f"Evicted model [{deleted['model_name']}], generated at: {deleted['gen_time']}")

        # ------------------------- Data insertion -------------------------
        with open(temp_path, 'rb') as f:
            result = collection.insert_one({
                "model_name": args['model_name'],
                "model_data": f.read(),
                "gen_time": args['gen_time'],
                "params": args['params'],
                "descr": args['descr']
            })
        print(f"✅ Model {args['model_name']} saved successfully | document ID: {result.inserted_id}")
        return str(result.inserted_id)
    except Exception as e:
        # ------------------------- Error classification -------------------------
        error_type = "Database operation" if isinstance(e, (pymongo.errors.PyMongoError, RuntimeError)) else "System error"
        print(f"❌ {error_type} failed - details: {str(e)}")
        raise  # re-raise or swallow depending on business requirements
    finally:
        # ------------------------- Resource cleanup -------------------------
        if client:
            client.close()
        if temp_path and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except PermissionError:
                print(f"⚠️ Failed to remove temporary file: {temp_path}")
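
# Illustrative usage sketch for insert_trained_model_into_mongo (not part of the
# original module). Builds a tiny throwaway Keras model; the database/collection
# names and parameter payload are hypothetical.
def _example_store_keras_model():
    from datetime import datetime
    demo_model = tf.keras.Sequential([
        tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
        tf.keras.layers.Dense(1)
    ])
    demo_model.compile(optimizer='adam', loss='mse')
    insert_trained_model_into_mongo(demo_model, {
        'mongodb_database': 'demo_db',   # hypothetical database name
        'model_table': 'demo_models',    # hypothetical collection name
        'model_name': 'demo_lstm',       # hypothetical model name
        'gen_time': datetime.now(),
        'params': {'lr': 0.001, 'epochs': 10},
        'descr': 'demo model for illustration only',
    })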

def insert_scaler_model_into_mongo(feature_scaler_bytes: BytesIO, target_scaler_bytes: BytesIO, args: Dict[str, Any]) -> str:
    """
    Store feature/target scalers in MongoDB, keeping the collection capped at 50 documents.

    Parameters:
        feature_scaler_bytes: BytesIO - serialized feature scaler
        target_scaler_bytes: BytesIO - serialized target scaler
        args : dict with the following keys:
            mongodb_database: database name
            scaler_table: collection name
            model_name: associated model name
            gen_time: generation time (datetime object)
    """
    # ------------------------- Argument validation -------------------------
    required_keys = {'mongodb_database', 'scaler_table', 'model_name', 'gen_time'}
    if missing := required_keys - args.keys():
        raise ValueError(f"Missing required parameters: {missing}")

    # ------------------------- Configuration -------------------------
    # Read the connection string from an environment variable (keeps credentials out of the code)
    mongodb_conn = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")

    # ------------------------- Input validation -------------------------
    for buf, name in [(feature_scaler_bytes, "feature scaler"),
                      (target_scaler_bytes, "target scaler")]:
        if not isinstance(buf, BytesIO):
            raise TypeError(f"{name} must be a BytesIO instance")
        if buf.getbuffer().nbytes == 0:
            raise ValueError(f"{name} byte stream is empty")

    client = None
    try:
        # ------------------------- Database connection -------------------------
        client = MongoClient(mongodb_conn)
        db = client[args['mongodb_database']]
        collection = db[args['scaler_table']]

        # ------------------------- Index maintenance -------------------------
        if "gen_time_1" not in collection.index_information():
            collection.create_index([("gen_time", ASCENDING)], name="gen_time_1")
            print("⏱️ Created gen_time index")

        # ------------------------- Capacity control -------------------------
        # Approximate count for speed (an error of a few dozen documents is acceptable)
        if collection.estimated_document_count() >= 50:
            # Atomically remove the oldest document
            if deleted := collection.find_one_and_delete(
                sort=[("gen_time", ASCENDING)],
                projection={"_id": 1, "model_name": 1, "gen_time": 1}
            ):
                print(f"🗑️ Evicted oldest scaler [{deleted['model_name']}], generated at: {deleted['gen_time']}")

        # ------------------------- Data insertion -------------------------
        # Make sure the byte-stream pointers are at the start
        feature_scaler_bytes.seek(0)
        target_scaler_bytes.seek(0)
        result = collection.insert_one({
            "model_name": args['model_name'],
            "gen_time": args['gen_time'],
            "feature_scaler": feature_scaler_bytes.read(),
            "target_scaler": target_scaler_bytes.read()
        })
        print(f"✅ Scalers for {args['model_name']} saved successfully | document ID: {result.inserted_id}")
        return str(result.inserted_id)
    except Exception as e:
        # ------------------------- Error classification -------------------------
        error_type = "Database operation" if isinstance(e, (pymongo.errors.PyMongoError, ValueError)) else "System error"
        print(f"❌ {error_type} failed - details: {str(e)}")
        raise  # re-raise or swallow depending on business requirements
    finally:
        # ------------------------- Resource cleanup -------------------------
        if client:
            client.close()
        # Reset the byte-stream pointers so the buffers can be reused
        feature_scaler_bytes.seek(0)
        target_scaler_bytes.seek(0)
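
# Illustrative usage sketch for insert_scaler_model_into_mongo (not part of the
# original module). Serializes two scalers into BytesIO buffers with joblib, which
# matches how get_scaler_model_from_mongo deserializes them below. The use of
# scikit-learn's MinMaxScaler and all names here are assumptions for illustration.
def _example_store_scalers():
    from datetime import datetime
    from sklearn.preprocessing import MinMaxScaler
    import numpy as np
    feature_scaler = MinMaxScaler().fit(np.random.rand(10, 4))  # dummy 4-feature data
    target_scaler = MinMaxScaler().fit(np.random.rand(10, 1))   # dummy 1-column target
    feature_buf, target_buf = BytesIO(), BytesIO()
    joblib.dump(feature_scaler, feature_buf)
    joblib.dump(target_scaler, target_buf)
    insert_scaler_model_into_mongo(feature_buf, target_buf, {
        'mongodb_database': 'demo_db',   # hypothetical database name
        'scaler_table': 'demo_scalers',  # hypothetical collection name
        'model_name': 'demo_lstm',       # hypothetical model name
        'gen_time': datetime.now(),
    })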

def get_h5_model_from_mongo(args: Dict[str, Any], custom_objects: Optional[Dict[str, Any]] = None) -> Optional[tf.keras.Model]:
    """
    Fetch the latest version of a named model from MongoDB.

    Parameters:
        args : dict with the following keys:
            mongodb_database: database name
            model_table: collection name
            model_name: name of the model to fetch
        custom_objects: dict - custom Keras objects needed to deserialize the model

    Returns:
        tf.keras.Model - the loaded Keras model, or None if no record is found
    """
    # ------------------------- Argument validation -------------------------
    required_keys = {'mongodb_database', 'model_table', 'model_name'}
    if missing := required_keys - args.keys():
        raise ValueError(f"❌ Missing required parameters: {missing}")

    # ------------------------- Environment configuration -------------------------
    mongo_uri = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
    client = None
    try:
        # ------------------------- Database connection -------------------------
        client = MongoClient(
            mongo_uri,
            maxPoolSize=10,  # connection pool size
            socketTimeoutMS=5000
        )
        db = client[args['mongodb_database']]
        collection = db[args['model_table']]

        # ------------------------- Index maintenance -------------------------
        index_name = "model_gen_time_idx"
        if index_name not in collection.index_information():
            collection.create_index(
                [("model_name", 1), ("gen_time", DESCENDING)],
                name=index_name
            )
            print("⏱️ Created compound index")

        # ------------------------- Query -------------------------
        model_doc = collection.find_one(
            {"model_name": args['model_name']},
            sort=[('gen_time', DESCENDING)],
            projection={"model_data": 1, "gen_time": 1}  # fetch only the required fields
        )
        if not model_doc:
            print(f"⚠️ No record found for model '{args['model_name']}'")
            return None

        # ------------------------- Model loading -------------------------
        model_data = model_doc['model_data']  # raw H5 bytes
        # Loading straight from a BytesIO buffer via h5py was considered here,
        # but writing to a temporary file and loading from disk proved simpler.
        with tempfile.NamedTemporaryFile(suffix=".h5", delete=False) as tmp_file:
            tmp_file.write(model_data)
            tmp_file_path = tmp_file.name  # path of the temporary file
        # Load the model from the temporary file
        model = tf.keras.models.load_model(tmp_file_path, custom_objects=custom_objects)
        os.remove(tmp_file_path)  # remove the temporary file once the model is loaded
        print(f"Model {args['model_name']} loaded successfully from MongoDB!")
        return model
    except tf.errors.NotFoundError as e:
        print(f"❌ The saved model is missing required components: {str(e)}")
        raise RuntimeError("Incomplete model architecture") from e
    except Exception as e:
        print(f"❌ System error: {str(e)}")
        raise
    finally:
        # ------------------------- Resource cleanup -------------------------
        if client:
            client.close()
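
# Illustrative usage sketch for get_h5_model_from_mongo (not part of the original
# module). Assumes a model was previously stored by insert_trained_model_into_mongo;
# the names below are hypothetical, and custom_objects is only needed when the saved
# model uses custom layers or losses.
def _example_load_keras_model():
    model = get_h5_model_from_mongo({
        'mongodb_database': 'demo_db',  # hypothetical database name
        'model_table': 'demo_models',   # hypothetical collection name
        'model_name': 'demo_lstm',      # hypothetical model name
    })
    if model is not None:
        model.summary()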

def get_scaler_model_from_mongo(args: Dict[str, Any], only_feature_scaler: bool = False) -> Union[object, Tuple[object, object]]:
    """
    Load the latest preprocessing scalers for a model from MongoDB.

    Parameters:
        args : must contain the keys:
            - mongodb_database: database name
            - scaler_table: collection name
            - model_name: target model name
        only_feature_scaler : if True, return only the feature scaler

    Returns:
        A single scaler object, or a (feature_scaler, target_scaler) tuple

    Raises:
        ValueError : missing parameters or wrong types
        RuntimeError : data access failure
    """
    # ------------------------- Argument validation -------------------------
    required_keys = {'mongodb_database', 'scaler_table', 'model_name'}
    if missing := required_keys - args.keys():
        raise ValueError(f"❌ Missing required parameters: {missing}")

    # ------------------------- Environment configuration -------------------------
    mongo_uri = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
    client = None
    try:
        # ------------------------- Database connection -------------------------
        client = MongoClient(
            mongo_uri,
            maxPoolSize=20,                  # connection pool limit
            socketTimeoutMS=3000,            # 3 s socket timeout
            serverSelectionTimeoutMS=5000    # 5 s server selection timeout
        )
        db = client[args['mongodb_database']]
        collection = db[args['scaler_table']]

        # ------------------------- Index maintenance -------------------------
        index_name = "model_gen_time_idx"
        if index_name not in collection.index_information():
            collection.create_index(
                [("model_name", 1), ("gen_time", DESCENDING)],
                name=index_name,
                background=True  # build in the background to avoid blocking
            )
            print("⏱️ Created compound index for scalers")

        # ------------------------- Query -------------------------
        scaler_doc = collection.find_one(
            {"model_name": args['model_name']},
            sort=[('gen_time', DESCENDING)],
            projection={"feature_scaler": 1, "target_scaler": 1, "gen_time": 1}
        )
        if not scaler_doc:
            raise RuntimeError(f"⚠️ No scaler record found for model {args['model_name']}")

        # ------------------------- Deserialization -------------------------
        def load_scaler(data: bytes) -> object:
            """Safely load a serialized object."""
            with BytesIO(data) as buffer:
                buffer.seek(0)  # make sure the pointer is at the start
                try:
                    return joblib.load(buffer)
                except (pickle.UnpicklingError, EOFError) as e:
                    raise RuntimeError("Deserialization failed (possible version mismatch)") from e

        # Load the feature scaler
        feature_data = scaler_doc["feature_scaler"]
        if not isinstance(feature_data, bytes):
            raise RuntimeError("Unexpected feature scaler data format")
        feature_scaler = load_scaler(feature_data)
        if only_feature_scaler:
            return feature_scaler

        # Load the target scaler
        target_data = scaler_doc["target_scaler"]
        if not isinstance(target_data, bytes):
            raise RuntimeError("Unexpected target scaler data format")
        target_scaler = load_scaler(target_data)
        print(f"✅ Loaded scalers for {args['model_name']} (version time: {scaler_doc.get('gen_time', 'unknown')})")
        return feature_scaler, target_scaler
    except PyMongoError as e:
        raise RuntimeError(f"🔌 Database operation failed: {str(e)}") from e
    except RuntimeError as e:
        raise RuntimeError(f"🔌 MongoDB access failed: {str(e)}") from e  # re-wrap the already classified error
    except Exception as e:
        raise RuntimeError(f"❌ Unexpected system error: {str(e)}") from e
    finally:
        # ------------------------- Resource cleanup -------------------------
        if client:
            client.close()
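
# Illustrative usage sketch for get_scaler_model_from_mongo (not part of the original
# module). Fetches the scaler pair stored by insert_scaler_model_into_mongo and applies
# the feature scaler to dummy input; names and shapes are hypothetical and must match
# whatever the scalers were originally fitted on.
def _example_load_scalers():
    import numpy as np
    feature_scaler, target_scaler = get_scaler_model_from_mongo({
        'mongodb_database': 'demo_db',   # hypothetical database name
        'scaler_table': 'demo_scalers',  # hypothetical collection name
        'model_name': 'demo_lstm',       # hypothetical model name
    })
    scaled = feature_scaler.transform(np.random.rand(5, 4))  # 4 features, as in the fit above
    print(scaled.shape)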