database_dml.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794
  1. import pymongo
  2. from pymongo import MongoClient, UpdateOne, DESCENDING, ASCENDING
  3. from pymongo.errors import PyMongoError
  4. import pandas as pd
  5. from sqlalchemy import create_engine
  6. import pickle
  7. from io import BytesIO
  8. import joblib
  9. import json
  10. import tensorflow as tf
  11. import os
  12. import tempfile
  13. import jaydebeapi
  14. import toml
  15. from typing import Dict, Any, Optional, Union, Tuple
  16. from datetime import datetime, timedelta
# Load the read-only global configuration from database.toml located next to
# this module. `config` is shared by every function below; treat it as read-only.
current_dir = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(current_dir, 'database.toml'), 'r', encoding='utf-8') as f:
    config = toml.load(f)  # read-only global configuration
# Path to the standalone Hive JDBC driver used by get_xmo_data_from_hive.
jar_file = os.path.join(current_dir, 'jar/hive-jdbc-standalone.jar')
  22. def get_data_from_mongo(args):
  23. # 获取 hive 配置部分
  24. mongodb_connection = config['mongodb']['mongodb_connection']
  25. mongodb_database = args['mongodb_database']
  26. mongodb_read_table = args['mongodb_read_table']
  27. query_dict = {}
  28. if 'timeBegin' in args.keys():
  29. timeBegin = args['timeBegin']
  30. query_dict.update({"$gte": timeBegin})
  31. if 'timeEnd' in args.keys():
  32. timeEnd = args['timeEnd']
  33. query_dict.update({"$lte": timeEnd})
  34. client = MongoClient(mongodb_connection)
  35. # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
  36. db = client[mongodb_database]
  37. collection = db[mongodb_read_table] # 集合名称
  38. if len(query_dict) != 0:
  39. query = {"dateTime": query_dict}
  40. cursor = collection.find(query)
  41. else:
  42. cursor = collection.find()
  43. data = list(cursor)
  44. df = pd.DataFrame(data)
  45. # 4. 删除 _id 字段(可选)
  46. if '_id' in df.columns:
  47. df = df.drop(columns=['_id'])
  48. client.close()
  49. return df
  50. def get_df_list_from_mongo(args):
  51. # 获取 hive 配置部分
  52. mongodb_connection = config['mongodb']['mongodb_connection']
  53. mongodb_database, mongodb_read_table = args['mongodb_database'], args['mongodb_read_table'].split(',')
  54. df_list = []
  55. client = MongoClient(mongodb_connection)
  56. # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
  57. db = client[mongodb_database]
  58. for table in mongodb_read_table:
  59. collection = db[table] # 集合名称
  60. data_from_db = collection.find() # 这会返回一个游标(cursor)
  61. # 将游标转换为列表,并创建 pandas DataFrame
  62. df = pd.DataFrame(list(data_from_db))
  63. if '_id' in df.columns:
  64. df = df.drop(columns=['_id'])
  65. df_list.append(df)
  66. client.close()
  67. return df_list
  68. def insert_data_into_mongo(res_df, args):
  69. """
  70. 插入数据到 MongoDB 集合中,可以选择覆盖、追加或按指定的 key 进行更新插入。
  71. 参数:
  72. - res_df: 要插入的 DataFrame 数据
  73. - args: 包含 MongoDB 数据库和集合名称的字典
  74. - overwrite: 布尔值,True 表示覆盖,False 表示追加
  75. - update_keys: 列表,指定用于匹配的 key 列,如果存在则更新,否则插入 'col1','col2'
  76. """
  77. # 获取 hive 配置部分
  78. mongodb_connection = config['mongodb']['mongodb_connection']
  79. mongodb_database = args['mongodb_database']
  80. mongodb_write_table = args['mongodb_write_table']
  81. overwrite = 1
  82. update_keys = None
  83. if 'overwrite' in args.keys():
  84. overwrite = int(args['overwrite'])
  85. if 'update_keys' in args.keys():
  86. update_keys = args['update_keys'].split(',')
  87. client = MongoClient(mongodb_connection)
  88. db = client[mongodb_database]
  89. collection = db[mongodb_write_table]
  90. # 覆盖模式:删除现有集合
  91. if overwrite:
  92. if mongodb_write_table in db.list_collection_names():
  93. collection.drop()
  94. print(f"Collection '{mongodb_write_table}' already exists, deleted successfully!")
  95. # 将 DataFrame 转为字典格式
  96. data_dict = res_df.to_dict("records") # 每一行作为一个字典
  97. # 如果没有数据,直接返回
  98. if not data_dict:
  99. print("No data to insert.")
  100. return
  101. # 如果指定了 update_keys,则执行 upsert(更新或插入)
  102. if update_keys and not overwrite:
  103. operations = []
  104. for record in data_dict:
  105. # 构建查询条件,用于匹配要更新的文档
  106. query = {key: record[key] for key in update_keys}
  107. operations.append(UpdateOne(query, {'$set': record}, upsert=True))
  108. # 批量执行更新/插入操作
  109. if operations:
  110. result = collection.bulk_write(operations)
  111. client.close()
  112. print(f"Matched: {result.matched_count}, Upserts: {result.upserted_count}")
  113. else:
  114. # 追加模式:直接插入新数据
  115. collection.insert_many(data_dict)
  116. client.close()
  117. print("Data inserted successfully!")
  118. def get_data_fromMysql(params):
  119. mysql_conn = params['mysql_conn']
  120. query_sql = params['query_sql']
  121. # 数据库读取实测气象
  122. engine = create_engine(f"mysql+pymysql://{mysql_conn}")
  123. # 定义SQL查询
  124. with engine.connect() as conn:
  125. df = pd.read_sql_query(query_sql, conn)
  126. return df
  127. def insert_pickle_model_into_mongo(model, args):
  128. # 获取 hive 配置部分
  129. mongodb_connection = config['mongodb']['mongodb_connection']
  130. mongodb_database, mongodb_write_table, model_name = args['mongodb_database'], args['mongodb_write_table'], args[
  131. 'model_name']
  132. client = MongoClient(mongodb_connection)
  133. db = client[mongodb_database]
  134. # 序列化模型
  135. model_bytes = pickle.dumps(model)
  136. model_data = {
  137. 'model_name': model_name,
  138. 'model': model_bytes, # 将模型字节流存入数据库
  139. }
  140. print('Training completed!')
  141. if mongodb_write_table in db.list_collection_names():
  142. db[mongodb_write_table].drop()
  143. print(f"Collection '{mongodb_write_table} already exist, deleted successfully!")
  144. collection = db[mongodb_write_table] # 集合名称
  145. collection.insert_one(model_data)
  146. client.close()
  147. print("model inserted successfully!")
  148. def get_pickle_model_from_mongo(args):
  149. mongodb_connection = config['mongodb']['mongodb_connection']
  150. mongodb_database, mongodb_model_table, model_name = args['mongodb_database'], args['mongodb_model_table'], args['model_name']
  151. client = MongoClient(mongodb_connection)
  152. db = client[mongodb_database]
  153. collection = db[mongodb_model_table]
  154. model_data = collection.find_one({"model_name": model_name})
  155. if model_data is not None:
  156. model_binary = model_data['model'] # 确保这个字段是存储模型的二进制数据
  157. # 反序列化模型
  158. model = pickle.loads(model_binary)
  159. return model
  160. else:
  161. return None
  162. def insert_h5_model_into_mongo(model, feature_scaler_bytes, target_scaler_bytes, args):
  163. # 获取 hive 配置部分
  164. mongodb_connection = config['mongodb']['mongodb_connection']
  165. mongodb_database, scaler_table, model_table, model_name = args['mongodb_database'], args['scaler_table'], args[
  166. 'model_table'], args['model_name']
  167. client = MongoClient(mongodb_connection)
  168. db = client[mongodb_database]
  169. if scaler_table in db.list_collection_names():
  170. db[scaler_table].drop()
  171. print(f"Collection '{scaler_table} already exist, deleted successfully!")
  172. collection = db[scaler_table] # 集合名称
  173. # Save the scalers in MongoDB as binary data
  174. collection.insert_one({
  175. "feature_scaler": feature_scaler_bytes.read(),
  176. "target_scaler": target_scaler_bytes.read()
  177. })
  178. print("scaler_model inserted successfully!")
  179. if model_table in db.list_collection_names():
  180. db[model_table].drop()
  181. print(f"Collection '{model_table} already exist, deleted successfully!")
  182. model_table = db[model_table]
  183. fd, temp_path = None, None
  184. client = None
  185. try:
  186. # ------------------------- 临时文件处理 -------------------------
  187. fd, temp_path = tempfile.mkstemp(suffix='.keras')
  188. os.close(fd) # 立即释放文件锁
  189. # ------------------------- 模型保存 -------------------------
  190. try:
  191. model.save(temp_path) # 不指定save_format,默认使用keras新格式
  192. except Exception as e:
  193. raise RuntimeError(f"模型保存失败: {str(e)}") from e
  194. # ------------------------- 数据插入 -------------------------
  195. with open(temp_path, 'rb') as f:
  196. result = model_table.insert_one({
  197. "model_name": args['model_name'],
  198. "model_data": f.read(),
  199. })
  200. print(f"✅ 模型 {args['model_name']} 保存成功 | 文档ID: {result.inserted_id}")
  201. return str(result.inserted_id)
  202. except Exception as e:
  203. # ------------------------- 异常分类处理 -------------------------
  204. error_type = "数据库操作" if isinstance(e, (pymongo.errors.PyMongoError, RuntimeError)) else "系统错误"
  205. print(f"❌ {error_type} - 详细错误: {str(e)}")
  206. raise # 根据业务需求决定是否重新抛出
  207. finally:
  208. # ------------------------- 资源清理 -------------------------
  209. if client:
  210. client.close()
  211. if temp_path and os.path.exists(temp_path):
  212. try:
  213. os.remove(temp_path)
  214. except PermissionError:
  215. print(f"⚠️ 临时文件清理失败: {temp_path}")
  216. def insert_trained_model_into_mongo(model: tf.keras.Model, args: Dict[str, Any]) -> str:
  217. """
  218. 将训练好的H5模型插入MongoDB,自动维护集合容量不超过50个模型
  219. 参数:
  220. model : keras模型 - 训练好的Keras模型
  221. args : dict - 包含以下键的字典:
  222. mongodb_database: 数据库名称
  223. model_table: 集合名称
  224. model_name: 模型名称
  225. gen_time: 模型生成时间(datetime对象)
  226. params: 模型参数(JSON可序列化对象)
  227. descr: 模型描述文本
  228. """
  229. # ------------------------- 参数校验 -------------------------
  230. required_keys = {'mongodb_database', 'model_table', 'model_name',
  231. 'gen_time', 'params', 'descr'}
  232. if missing := required_keys - args.keys():
  233. raise ValueError(f"缺少必要参数: {missing}")
  234. # ------------------------- 配置解耦 -------------------------
  235. # 从环境变量获取连接信息(更安全)
  236. mongodb_connection = os.getenv("MONGO_URI", config['mongodb']['mongodb_connection'])
  237. # ------------------------- 资源初始化 -------------------------
  238. fd, temp_path = None, None
  239. client = None
  240. try:
  241. # ------------------------- 临时文件处理 -------------------------
  242. fd, temp_path = tempfile.mkstemp(suffix='.keras')
  243. os.close(fd) # 立即释放文件锁
  244. # ------------------------- 模型保存 -------------------------
  245. try:
  246. model.save(temp_path) # 不指定save_format,默认使用keras新格式
  247. except Exception as e:
  248. raise RuntimeError(f"模型保存失败: {str(e)}") from e
  249. # ------------------------- 数据库连接 -------------------------
  250. client = MongoClient(mongodb_connection)
  251. db = client[args['mongodb_database']]
  252. collection = db[args['model_table']]
  253. # ------------------------- 索引检查 -------------------------
  254. # index_info = collection.index_information()
  255. # if "gen_time_1" not in index_info:
  256. # print("开始创建索引...")
  257. # collection.create_index(
  258. # [("gen_time", ASCENDING)],
  259. # name="gen_time_1",
  260. # background=True
  261. # )
  262. # print("索引创建成功")
  263. # else:
  264. # print("索引已存在,跳过创建")
  265. # ------------------------- 容量控制 -------------------------
  266. # 使用更高效的计数方式
  267. if collection.estimated_document_count() >= 50:
  268. # 原子性删除操作
  269. if deleted := collection.find_one_and_delete(
  270. sort=[("gen_time", ASCENDING)],
  271. projection={"_id": 1, "model_name": 1, "gen_time": 1}
  272. ):
  273. print(f"已淘汰模型 [{deleted['model_name']}] 生成时间: {deleted['gen_time']}")
  274. # ------------------------- 数据插入 -------------------------
  275. with open(temp_path, 'rb') as f:
  276. result = collection.insert_one({
  277. "model_name": args['model_name'],
  278. "model_data": f.read(),
  279. "gen_time": args['gen_time'],
  280. "params": args['params'],
  281. "descr": args['descr']
  282. })
  283. print(f"✅ 模型 {args['model_name']} 保存成功 | 文档ID: {result.inserted_id}")
  284. return str(result.inserted_id)
  285. except Exception as e:
  286. # ------------------------- 异常分类处理 -------------------------
  287. error_type = "数据库操作" if isinstance(e, (pymongo.errors.PyMongoError, RuntimeError)) else "系统错误"
  288. print(f"❌ {error_type} - 详细错误: {str(e)}")
  289. raise # 根据业务需求决定是否重新抛出
  290. finally:
  291. # ------------------------- 资源清理 -------------------------
  292. if client:
  293. client.close()
  294. if temp_path and os.path.exists(temp_path):
  295. try:
  296. os.remove(temp_path)
  297. except PermissionError:
  298. print(f"⚠️ 临时文件清理失败: {temp_path}")
  299. def insert_scaler_model_into_mongo(feature_scaler_bytes: BytesIO, target_scaler_bytes: BytesIO,
  300. args: Dict[str, Any]) -> str:
  301. """
  302. 将特征缩放器存储到MongoDB,自动维护集合容量不超过50个文档
  303. 参数:
  304. feature_scaler_bytes: BytesIO - 特征缩放器字节流
  305. scaled_target_bytes: BytesIO - 目标缩放器字节流
  306. args : dict - 包含以下键的字典:
  307. mongodb_database: 数据库名称
  308. scaler_table: 集合名称
  309. model_name: 关联模型名称
  310. gen_time: 生成时间(datetime对象)
  311. """
  312. # ------------------------- 参数校验 -------------------------
  313. required_keys = {'mongodb_database', 'scaler_table', 'model_name', 'gen_time'}
  314. if missing := required_keys - args.keys():
  315. raise ValueError(f"缺少必要参数: {missing}")
  316. # ------------------------- 配置解耦 -------------------------
  317. # 从环境变量获取连接信息(安全隔离凭证)
  318. mongodb_conn = os.getenv("MONGO_URI", config['mongodb']['mongodb_connection'])
  319. # ------------------------- 输入验证 -------------------------
  320. for buf, name in [(feature_scaler_bytes, "特征缩放器"),
  321. (target_scaler_bytes, "目标缩放器")]:
  322. if not isinstance(buf, BytesIO):
  323. raise TypeError(f"{name} 必须为BytesIO类型")
  324. if buf.getbuffer().nbytes == 0:
  325. raise ValueError(f"{name} 字节流为空")
  326. client = None
  327. try:
  328. # ------------------------- 数据库连接 -------------------------
  329. client = MongoClient(mongodb_conn)
  330. db = client[args['mongodb_database']]
  331. collection = db[args['scaler_table']]
  332. # ------------------------- 索引维护 -------------------------
  333. # if "gen_time_1" not in collection.index_information():
  334. # collection.create_index([("gen_time", ASCENDING)], name="gen_time_1")
  335. # print("⏱️ 已创建时间排序索引")
  336. # ------------------------- 容量控制 -------------------------
  337. # 使用近似计数提升性能(误差在几十条内可接受)
  338. if collection.estimated_document_count() >= 50:
  339. # 原子性删除操作(保证事务完整性)
  340. if deleted := collection.find_one_and_delete(
  341. sort=[("gen_time", ASCENDING)],
  342. projection={"_id": 1, "model_name": 1, "gen_time": 1}
  343. ):
  344. print(f"🗑️ 已淘汰最旧缩放器 [{deleted['model_name']}] 生成时间: {deleted['gen_time']}")
  345. # ------------------------- 数据插入 -------------------------
  346. # 确保字节流指针位置正确
  347. feature_scaler_bytes.seek(0)
  348. target_scaler_bytes.seek(0)
  349. result = collection.insert_one({
  350. "model_name": args['model_name'],
  351. "gen_time": args['gen_time'],
  352. "feature_scaler": feature_scaler_bytes.read(),
  353. "target_scaler": target_scaler_bytes.read()
  354. })
  355. print(f"✅ 缩放器 {args['model_name']} 保存成功 | 文档ID: {result.inserted_id}")
  356. return str(result.inserted_id)
  357. except Exception as e:
  358. # ------------------------- 异常分类处理 -------------------------
  359. error_type = "数据库操作" if isinstance(e, (pymongo.errors.PyMongoError, ValueError)) else "系统错误"
  360. print(f"❌ {error_type}异常 - 详细错误: {str(e)}")
  361. raise # 根据业务需求决定是否重新抛出
  362. finally:
  363. # ------------------------- 资源清理 -------------------------
  364. if client:
  365. client.close()
  366. # 重置字节流指针(确保后续可复用)
  367. feature_scaler_bytes.seek(0)
  368. target_scaler_bytes.seek(0)
  369. def get_h5_model_from_mongo(args: Dict[str, Any], custom_objects: Optional[Dict[str, Any]] = None) -> Optional[
  370. tf.keras.Model]:
  371. """
  372. 从MongoDB获取指定模型的最新版本
  373. 参数:
  374. args : dict - 包含以下键的字典:
  375. mongodb_database: 数据库名称
  376. model_table: 集合名称
  377. model_name: 要获取的模型名称
  378. custom_objects: dict - 自定义Keras对象字典
  379. 返回:
  380. tf.keras.Model - 加载成功的Keras模型
  381. """
  382. # ------------------------- 参数校验 -------------------------
  383. required_keys = {'mongodb_database', 'model_table', 'model_name'}
  384. if missing := required_keys - args.keys():
  385. raise ValueError(f"❌ 缺失必要参数: {missing}")
  386. # ------------------------- 环境配置 -------------------------
  387. mongo_uri = os.getenv("MONGO_URI", config['mongodb']['mongodb_connection'])
  388. client = None
  389. tmp_file_path = None # 用于跟踪临时文件路径
  390. try:
  391. # ------------------------- 数据库连接 -------------------------
  392. client = MongoClient(
  393. mongo_uri,
  394. maxPoolSize=10, # 连接池优化
  395. socketTimeoutMS=5000
  396. )
  397. db = client[args['mongodb_database']]
  398. collection = db[args['model_table']]
  399. # ------------------------- 索引维护 -------------------------
  400. index_name = "model_gen_time_idx"
  401. if index_name not in collection.index_information():
  402. collection.create_index(
  403. [("model_name", 1), ("gen_time", DESCENDING)],
  404. name=index_name
  405. )
  406. print("⏱️ 已创建复合索引")
  407. # ------------------------- 高效查询 -------------------------
  408. model_doc = collection.find_one(
  409. {"model_name": args['model_name']},
  410. sort=[('gen_time', DESCENDING)],
  411. projection={"model_data": 1, "gen_time": 1} # 获取必要字段
  412. )
  413. if not model_doc:
  414. print(f"⚠️ 未找到模型 '{args['model_name']}' 的有效记录")
  415. return None
  416. # ------------------------- 内存优化加载 -------------------------
  417. if model_doc:
  418. model_data = model_doc['model_data'] # 获取模型的二进制数据
  419. # # 将二进制数据加载到 BytesIO 缓冲区
  420. # model_buffer = BytesIO(model_data)
  421. # # 确保指针在起始位置
  422. # model_buffer.seek(0)
  423. # # 从缓冲区加载模型
  424. # # 使用 h5py 和 BytesIO 从内存中加载模型
  425. # with h5py.File(model_buffer, 'r', driver='fileobj') as f:
  426. # model = tf.keras.models.load_model(f, custom_objects=custom_objects)
  427. # 创建临时文件
  428. with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
  429. tmp_file.write(model_data)
  430. tmp_file_path = tmp_file.name # 获取临时文件路径
  431. # 从临时文件加载模型
  432. model = tf.keras.models.load_model(tmp_file_path, custom_objects=custom_objects)
  433. print(f"{args['model_name']}模型成功从 MongoDB 加载!")
  434. return model
  435. except tf.errors.NotFoundError as e:
  436. print(f"❌ 模型结构缺失关键组件: {str(e)}")
  437. raise RuntimeError("模型架构不完整") from e
  438. except Exception as e:
  439. print(f"❌ 系统异常: {str(e)}")
  440. raise
  441. finally:
  442. # ------------------------- 资源清理 -------------------------
  443. if client:
  444. client.close()
  445. # 确保删除临时文件
  446. if tmp_file_path and os.path.exists(tmp_file_path):
  447. try:
  448. os.remove(tmp_file_path)
  449. print(f"🧹 已清理临时文件: {tmp_file_path}")
  450. except Exception as cleanup_err:
  451. print(f"⚠️ 临时文件清理失败: {str(cleanup_err)}")
  452. def get_keras_model_from_mongo(
  453. args: Dict[str, Any],
  454. custom_objects: Optional[Dict[str, Any]] = None
  455. ) -> Optional[tf.keras.Model]:
  456. """
  457. 从MongoDB获取指定模型的最新版本(支持Keras格式)
  458. 参数:
  459. args : dict - 包含以下键的字典:
  460. mongodb_database: 数据库名称
  461. model_table: 集合名称
  462. model_name: 要获取的模型名称
  463. custom_objects: dict - 自定义Keras对象字典
  464. 返回:
  465. tf.keras.Model - 加载成功的Keras模型
  466. """
  467. # ------------------------- 参数校验 -------------------------
  468. required_keys = {'mongodb_database', 'model_table', 'model_name'}
  469. if missing := required_keys - args.keys():
  470. raise ValueError(f"❌ 缺失必要参数: {missing}")
  471. # ------------------------- 环境配置 -------------------------
  472. mongo_uri = os.getenv("MONGO_URI", config['mongodb']['mongodb_connection'])
  473. client = None
  474. tmp_file_path = None # 用于跟踪临时文件路径
  475. try:
  476. # ------------------------- 数据库连接 -------------------------
  477. client = MongoClient(
  478. mongo_uri,
  479. maxPoolSize=10,
  480. socketTimeoutMS=5000
  481. )
  482. db = client[args['mongodb_database']]
  483. collection = db[args['model_table']]
  484. # ------------------------- 索引维护 -------------------------
  485. # index_name = "model_gen_time_idx"
  486. # if index_name not in collection.index_information():
  487. # collection.create_index(
  488. # [("model_name", 1), ("gen_time", DESCENDING)],
  489. # name=index_name
  490. # )
  491. # print("⏱️ 已创建复合索引")
  492. # ------------------------- 高效查询 -------------------------
  493. model_doc = collection.find_one(
  494. {"model_name": args['model_name']},
  495. sort=[('gen_time', DESCENDING)],
  496. projection={"model_data": 1, "gen_time": 1, 'params': 1}
  497. )
  498. if not model_doc:
  499. print(f"⚠️ 未找到模型 '{args['model_name']}' 的有效记录")
  500. return None
  501. # ------------------------- 内存优化加载 -------------------------
  502. model_data = model_doc['model_data']
  503. model_params = model_doc['params']
  504. # 创建临时文件(自动删除)
  505. with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
  506. tmp_file.write(model_data)
  507. tmp_file_path = tmp_file.name # 记录文件路径
  508. # 从临时文件加载模型
  509. model = tf.keras.models.load_model(
  510. tmp_file_path,
  511. custom_objects=custom_objects
  512. )
  513. print(f"{args['model_name']} 模型成功从 MongoDB 加载!")
  514. return model, model_params
  515. except tf.errors.NotFoundError as e:
  516. print(f"❌ 模型结构缺失关键组件: {str(e)}")
  517. raise RuntimeError("模型架构不完整") from e
  518. except Exception as e:
  519. print(f"❌ 系统异常: {str(e)}")
  520. raise
  521. finally:
  522. # ------------------------- 资源清理 -------------------------
  523. if client:
  524. client.close()
  525. # 确保删除临时文件
  526. if tmp_file_path and os.path.exists(tmp_file_path):
  527. try:
  528. os.remove(tmp_file_path)
  529. print(f"🧹 已清理临时文件: {tmp_file_path}")
  530. except Exception as cleanup_err:
  531. print(f"⚠️ 临时文件清理失败: {str(cleanup_err)}")
def get_scaler_model_from_mongo(args: Dict[str, Any], only_feature_scaler: bool = False) -> Union[
    object, Tuple[object, object]]:
    """
    Load the latest (by gen_time) feature/target scalers for a model.

    Parameters:
        args : must contain:
            - mongodb_database: database name
            - scaler_table: collection name
            - model_name: target model name
        only_feature_scaler : when True, return only the feature scaler.

    Returns:
        A single scaler object, or a (feature_scaler, target_scaler) tuple.

    Raises:
        ValueError : missing parameters.
        RuntimeError : database or deserialization failures (all exceptions
                       are re-wrapped as RuntimeError, see the except chain).
    """
    # ------------------------- parameter validation -------------------------
    required_keys = {'mongodb_database', 'scaler_table', 'model_name'}
    if missing := required_keys - args.keys():
        raise ValueError(f"❌ 缺失必要参数: {missing}")
    # ------------------------- environment config -------------------------
    # MONGO_URI env var takes precedence over the TOML config value.
    mongo_uri = os.getenv("MONGO_URI", config['mongodb']['mongodb_connection'])
    client = None
    try:
        # ------------------------- database connection -------------------------
        client = MongoClient(
            mongo_uri,
            maxPoolSize=20,  # connection pool cap
            socketTimeoutMS=3000,  # 3s socket timeout
            serverSelectionTimeoutMS=5000  # 5s server selection timeout
        )
        db = client[args['mongodb_database']]
        collection = db[args['scaler_table']]
        # (index creation on (model_name, gen_time) intentionally disabled)
        # ------------------------- query latest document -------------------------
        scaler_doc = collection.find_one(
            {"model_name": args['model_name']},
            sort=[('gen_time', DESCENDING)],
            projection={"feature_scaler": 1, "target_scaler": 1, "gen_time": 1}
        )
        if not scaler_doc:
            raise RuntimeError(f"⚠️ 找不到模型 {args['model_name']} 的缩放器记录")
        # ------------------------- deserialization -------------------------
        def load_scaler(data: bytes) -> object:
            """Deserialize a scaler object from raw bytes via joblib."""
            with BytesIO(data) as buffer:
                buffer.seek(0)  # ensure the pointer is at the start
                try:
                    return joblib.load(buffer)
                except joblib.UnpicklingError as e:
                    # NOTE(review): confirm joblib actually exposes
                    # UnpicklingError at top level; if not, evaluating this
                    # except clause would itself raise AttributeError.
                    raise RuntimeError("反序列化失败 (可能版本不兼容)") from e
        # Feature scaler: validate payload type, then deserialize.
        feature_data = scaler_doc["feature_scaler"]
        if not isinstance(feature_data, bytes):
            raise RuntimeError("特征缩放器数据格式异常")
        feature_scaler = load_scaler(feature_data)
        if only_feature_scaler:
            return feature_scaler
        # Target scaler: same validation and deserialization path.
        target_data = scaler_doc["target_scaler"]
        if not isinstance(target_data, bytes):
            raise RuntimeError("目标缩放器数据格式异常")
        target_scaler = load_scaler(target_data)
        print(f"✅ 成功加载 {args['model_name']} 的缩放器 (版本时间: {scaler_doc.get('gen_time', '未知')})")
        return feature_scaler, target_scaler
    except PyMongoError as e:
        raise RuntimeError(f"🔌 数据库操作失败: {str(e)}") from e
    except RuntimeError as e:
        # Re-wraps RuntimeErrors raised above with a uniform prefix.
        raise RuntimeError(f"🔌 mongo操作失败: {str(e)}") from e
    except Exception as e:
        raise RuntimeError(f"❌ 未知系统异常: {str(e)}") from e
    finally:
        # ------------------------- resource cleanup -------------------------
        if client:
            client.close()
  615. def normalize_key(s):
  616. return s.lower()
  617. def get_xmo_data_from_hive(args):
  618. # 获取 hive 配置部分
  619. hive_config = config['hive']
  620. jdbc_url = hive_config['jdbc_url']
  621. driver_class = hive_config['driver_class']
  622. user = hive_config['user']
  623. password = hive_config['password']
  624. features = config['xmo']['features']
  625. numeric_features = config['xmo']['numeric_features']
  626. if 'moment' not in args or 'farm_id' not in args:
  627. msg_error = 'One or more of the following parameters are missing: moment, farm_id!'
  628. return msg_error
  629. else:
  630. moment = args['moment']
  631. farm_id = args['farm_id']
  632. if 'current_date' in args:
  633. current_date = datetime.strptime(args['current_date'], "%Y%m%d")
  634. else:
  635. current_date = datetime.now()
  636. if 'days' in args:
  637. days = int(args['days']) + 1
  638. else:
  639. days = 1
  640. json_feature = f"nwp_xmo_{moment}"
  641. # 建立连接
  642. conn = jaydebeapi.connect(driver_class, jdbc_url, [user, password], jar_file)
  643. # 查询 Hive 表
  644. cursor = conn.cursor()
  645. query_sql = ""
  646. for i in range(0, days):
  647. sysdate_pre = (current_date + timedelta(days=i)).strftime("%Y%m%d")
  648. if i == 0:
  649. pass
  650. else:
  651. query_sql += "union \n"
  652. query_sql += """select rowkey,datatimestamp,{2} from hbase_forecast.forecast_xmo_d{3}
  653. where rowkey>='{0}-{1}0000' and rowkey<='{0}-{1}2345' \n""".format(
  654. farm_id, sysdate_pre, json_feature, i)
  655. print("query_sql\n", query_sql)
  656. cursor.execute(query_sql)
  657. # 获取列名
  658. columns = [desc[0] for desc in cursor.description]
  659. # 获取所有数据
  660. rows = cursor.fetchall()
  661. # 转成 DataFrame
  662. df = pd.DataFrame(rows, columns=columns)
  663. cursor.close()
  664. conn.close()
  665. df[json_feature] = df[json_feature].apply(lambda x: json.loads(x) if isinstance(x, str) else x)
  666. df_features = pd.json_normalize(df[json_feature])
  667. if 'forecastDatatime' not in df_features.columns:
  668. return "The returned data does not contain the forecastDatetime column — the data might be empty or null!"
  669. else:
  670. df_features['date_time'] = pd.to_datetime(df_features['forecastDatatime'], unit='ms', utc=True).dt.tz_convert(
  671. 'Asia/Shanghai').dt.strftime('%Y-%m-%d %H:%M:%S')
  672. df_features[numeric_features] = df_features[numeric_features].apply(pd.to_numeric, errors='coerce')
  673. return df_features[features]
if __name__ == "__main__":
    # Ad-hoc manual check: pull 13 days of the 06 moment XMO forecast for
    # farm J00883 starting 2025-06-09 and show a preview.
    print("Program starts execution!")
    args = {
        'moment': '06',
        'current_date': '20250609',
        'farm_id': 'J00883',
        'days': '13'
    }
    df = get_xmo_data_from_hive(args)
    print(df.head(2),df.shape)
    print("server start!")