database_dml.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825
  1. import pymongo
  2. from pymongo import MongoClient, UpdateOne, DESCENDING, ASCENDING
  3. from pymongo.errors import PyMongoError
  4. import pandas as pd
  5. from sqlalchemy import create_engine
  6. import pickle
  7. from io import BytesIO
  8. import joblib
  9. import json
  10. import tensorflow as tf
  11. import os
  12. import tempfile
  13. import jaydebeapi
  14. import toml
  15. from typing import Dict, Any, Optional, Union, Tuple
  16. from datetime import datetime, timedelta
# Load the TOML configuration file that sits next to this module.
current_dir = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(current_dir, 'database.toml'), 'r', encoding='utf-8') as f:
    config = toml.load(f)  # read-only global configuration shared by all helpers
# Standalone Hive JDBC driver used by jaydebeapi in get_xmo_data_from_hive.
jar_file = os.path.join(current_dir, 'jar/hive-jdbc-standalone.jar')
  22. def get_data_from_mongo(args):
  23. # 获取 hive 配置部分
  24. mongodb_connection = config['mongodb']['mongodb_connection']
  25. mongodb_database = args['mongodb_database']
  26. mongodb_read_table = args['mongodb_read_table']
  27. query_dict = {}
  28. if 'timeBegin' in args.keys():
  29. timeBegin = args['timeBegin']
  30. query_dict.update({"$gte": timeBegin})
  31. if 'timeEnd' in args.keys():
  32. timeEnd = args['timeEnd']
  33. query_dict.update({"$lte": timeEnd})
  34. client = MongoClient(mongodb_connection)
  35. # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
  36. db = client[mongodb_database]
  37. collection = db[mongodb_read_table] # 集合名称
  38. if len(query_dict) != 0:
  39. query = {"dateTime": query_dict}
  40. cursor = collection.find(query)
  41. else:
  42. cursor = collection.find()
  43. data = list(cursor)
  44. df = pd.DataFrame(data)
  45. # 4. 删除 _id 字段(可选)
  46. if '_id' in df.columns:
  47. df = df.drop(columns=['_id'])
  48. client.close()
  49. return df
  50. def get_df_list_from_mongo(args):
  51. # 获取 hive 配置部分
  52. mongodb_connection = config['mongodb']['mongodb_connection']
  53. mongodb_database, mongodb_read_table = args['mongodb_database'], args['mongodb_read_table'].split(',')
  54. df_list = []
  55. client = MongoClient(mongodb_connection)
  56. # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
  57. db = client[mongodb_database]
  58. for table in mongodb_read_table:
  59. collection = db[table] # 集合名称
  60. data_from_db = collection.find() # 这会返回一个游标(cursor)
  61. # 将游标转换为列表,并创建 pandas DataFrame
  62. df = pd.DataFrame(list(data_from_db))
  63. if '_id' in df.columns:
  64. df = df.drop(columns=['_id'])
  65. df_list.append(df)
  66. client.close()
  67. return df_list
  68. def delete_data_from_mongo(args):
  69. mongodb_connection = config['mongodb']['mongodb_connection']
  70. mongodb_database = args['mongodb_database']
  71. mongodb_write_table = args['mongodb_write_table']
  72. client = MongoClient(mongodb_connection)
  73. db = client[mongodb_database]
  74. collection = db[mongodb_write_table]
  75. if mongodb_write_table in db.list_collection_names():
  76. collection.drop()
  77. print(f"Collection '{mongodb_write_table}' already exists, deleted successfully!")
  78. else:
  79. print(f"Collection '{mongodb_write_table}' already not exists!")
  80. def insert_data_into_mongo(res_df, args):
  81. """
  82. 插入数据到 MongoDB 集合中,可以选择覆盖、追加或按指定的 key 进行更新插入。
  83. 参数:
  84. - res_df: 要插入的 DataFrame 数据
  85. - args: 包含 MongoDB 数据库和集合名称的字典
  86. - overwrite: 布尔值,True 表示覆盖,False 表示追加
  87. - update_keys: 列表,指定用于匹配的 key 列,如果存在则更新,否则插入 'col1','col2'
  88. """
  89. # 获取 hive 配置部分
  90. mongodb_connection = config['mongodb']['mongodb_connection']
  91. mongodb_database = args['mongodb_database']
  92. mongodb_write_table = args['mongodb_write_table']
  93. overwrite = 1
  94. update_keys = None
  95. if 'overwrite' in args.keys():
  96. overwrite = int(args['overwrite'])
  97. if 'update_keys' in args.keys():
  98. update_keys = args['update_keys'].split(',')
  99. client = MongoClient(mongodb_connection)
  100. db = client[mongodb_database]
  101. collection = db[mongodb_write_table]
  102. # 覆盖模式:删除现有集合
  103. if overwrite:
  104. if mongodb_write_table in db.list_collection_names():
  105. collection.drop()
  106. print(f"Collection '{mongodb_write_table}' already exists, deleted successfully!")
  107. # 将 DataFrame 转为字典格式
  108. data_dict = res_df.to_dict("records") # 每一行作为一个字典
  109. # 如果没有数据,直接返回
  110. if not data_dict:
  111. print("No data to insert.")
  112. return
  113. # 如果指定了 update_keys,则执行 upsert(更新或插入)
  114. if update_keys and not overwrite:
  115. operations = []
  116. for record in data_dict:
  117. # 构建查询条件,用于匹配要更新的文档
  118. query = {key: record[key] for key in update_keys}
  119. operations.append(UpdateOne(query, {'$set': record}, upsert=True))
  120. # 批量执行更新/插入操作
  121. if operations:
  122. result = collection.bulk_write(operations)
  123. client.close()
  124. print(f"Matched: {result.matched_count}, Upserts: {result.upserted_count}")
  125. else:
  126. # 追加模式:直接插入新数据
  127. collection.insert_many(data_dict)
  128. client.close()
  129. print("Data inserted successfully!")
  130. def get_data_fromMysql(params):
  131. mysql_conn = params['mysql_conn']
  132. query_sql = params['query_sql']
  133. # 数据库读取实测气象
  134. engine = create_engine(f"mysql+pymysql://{mysql_conn}")
  135. # 定义SQL查询
  136. with engine.connect() as conn:
  137. df = pd.read_sql_query(query_sql, conn)
  138. return df
  139. def insert_pickle_model_into_mongo(model, args, features=None):
  140. # 获取 hive 配置部分
  141. mongodb_connection = config['mongodb']['mongodb_connection']
  142. mongodb_database, mongodb_write_table, model_name = args['mongodb_database'], args['mongodb_write_table'], args[
  143. 'model_name']
  144. client = MongoClient(mongodb_connection)
  145. db = client[mongodb_database]
  146. # 序列化模型
  147. model_bytes = pickle.dumps(model)
  148. model_data = {
  149. 'model_name': model_name,
  150. 'model': model_bytes, # 将模型字节流存入数据库
  151. }
  152. if features is not None:
  153. model_data['features'] = features
  154. print('Training completed!')
  155. if mongodb_write_table in db.list_collection_names():
  156. db[mongodb_write_table].drop()
  157. print(f"Collection '{mongodb_write_table} already exist, deleted successfully!")
  158. collection = db[mongodb_write_table] # 集合名称
  159. collection.insert_one(model_data)
  160. client.close()
  161. print("model inserted successfully!")
  162. def get_pickle_model_from_mongo(args):
  163. mongodb_connection = config['mongodb']['mongodb_connection']
  164. mongodb_database, mongodb_model_table, model_name = args['mongodb_database'], args['mongodb_model_table'], args['model_name']
  165. client = MongoClient(mongodb_connection)
  166. db = client[mongodb_database]
  167. collection = db[mongodb_model_table]
  168. model_data = collection.find_one({"model_name": model_name})
  169. if model_data is not None:
  170. model_binary = model_data['model'] # 确保这个字段是存储模型的二进制数据
  171. # 反序列化模型
  172. model = pickle.loads(model_binary)
  173. return model
  174. else:
  175. return None
  176. def insert_h5_model_into_mongo(model, feature_scaler_bytes, target_scaler_bytes, args):
  177. # 获取 hive 配置部分
  178. mongodb_connection = config['mongodb']['mongodb_connection']
  179. mongodb_database, scaler_table, model_table, model_name = args['mongodb_database'], args['scaler_table'], args[
  180. 'model_table'], args['model_name']
  181. client = MongoClient(mongodb_connection)
  182. db = client[mongodb_database]
  183. if scaler_table in db.list_collection_names():
  184. db[scaler_table].drop()
  185. print(f"Collection '{scaler_table} already exist, deleted successfully!")
  186. collection = db[scaler_table] # 集合名称
  187. # Save the scalers in MongoDB as binary data
  188. collection.insert_one({
  189. "feature_scaler": feature_scaler_bytes.read(),
  190. "target_scaler": target_scaler_bytes.read()
  191. })
  192. print("scaler_model inserted successfully!")
  193. if model_table in db.list_collection_names():
  194. db[model_table].drop()
  195. print(f"Collection '{model_table} already exist, deleted successfully!")
  196. model_table = db[model_table]
  197. fd, temp_path = None, None
  198. client = None
  199. try:
  200. # ------------------------- 临时文件处理 -------------------------
  201. fd, temp_path = tempfile.mkstemp(suffix='.keras')
  202. os.close(fd) # 立即释放文件锁
  203. # ------------------------- 模型保存 -------------------------
  204. try:
  205. model.save(temp_path) # 不指定save_format,默认使用keras新格式
  206. except Exception as e:
  207. raise RuntimeError(f"模型保存失败: {str(e)}") from e
  208. # ------------------------- 数据插入 -------------------------
  209. with open(temp_path, 'rb') as f:
  210. result = model_table.insert_one({
  211. "model_name": args['model_name'],
  212. "model_data": f.read(),
  213. })
  214. print(f"✅ 模型 {args['model_name']} 保存成功 | 文档ID: {result.inserted_id}")
  215. return str(result.inserted_id)
  216. except Exception as e:
  217. # ------------------------- 异常分类处理 -------------------------
  218. error_type = "数据库操作" if isinstance(e, (pymongo.errors.PyMongoError, RuntimeError)) else "系统错误"
  219. print(f"❌ {error_type} - 详细错误: {str(e)}")
  220. raise # 根据业务需求决定是否重新抛出
  221. finally:
  222. # ------------------------- 资源清理 -------------------------
  223. if client:
  224. client.close()
  225. if temp_path and os.path.exists(temp_path):
  226. try:
  227. os.remove(temp_path)
  228. except PermissionError:
  229. print(f"⚠️ 临时文件清理失败: {temp_path}")
def insert_trained_model_into_mongo(model: tf.keras.Model, args: Dict[str, Any]) -> str:
    """
    Insert a trained Keras model into MongoDB, keeping the collection capped
    at 50 documents by evicting the oldest model (by gen_time).

    Parameters:
        model : trained Keras model.
        args  : dict containing:
            mongodb_database: database name
            model_table: collection name
            model_name: model name
            gen_time: model generation time (datetime)
            params: model parameters (JSON-serializable)
            descr: model description text

    Returns:
        str - the inserted document id.

    Raises:
        ValueError: when a required key is missing from *args*.
        RuntimeError: when saving the model to disk fails.
    """
    # ------------------------- parameter validation -------------------------
    required_keys = {'mongodb_database', 'model_table', 'model_name',
                     'gen_time', 'params', 'descr'}
    if missing := required_keys - args.keys():
        raise ValueError(f"缺少必要参数: {missing}")
    # ------------------------- configuration -------------------------
    # Prefer the connection string from the environment (keeps credentials
    # out of the on-disk config).
    mongodb_connection = os.getenv("MONGO_URI", config['mongodb']['mongodb_connection'])
    # ------------------------- resource initialization -------------------------
    fd, temp_path = None, None
    client = None
    try:
        # ------------------------- temp file handling -------------------------
        fd, temp_path = tempfile.mkstemp(suffix='.keras')
        os.close(fd)  # release the descriptor immediately
        # ------------------------- model save -------------------------
        try:
            model.save(temp_path)  # no save_format: defaults to the new keras format
        except Exception as e:
            raise RuntimeError(f"模型保存失败: {str(e)}") from e
        # ------------------------- database connection -------------------------
        client = MongoClient(mongodb_connection)
        db = client[args['mongodb_database']]
        collection = db[args['model_table']]
        # ------------------------- capacity control -------------------------
        # estimated_document_count() is cheaper than an exact count and
        # accurate enough for this cap.
        if collection.estimated_document_count() >= 50:
            # Atomically remove the oldest model by gen_time.
            if deleted := collection.find_one_and_delete(
                    sort=[("gen_time", ASCENDING)],
                    projection={"_id": 1, "model_name": 1, "gen_time": 1}
            ):
                print(f"已淘汰模型 [{deleted['model_name']}] 生成时间: {deleted['gen_time']}")
        # ------------------------- insert -------------------------
        with open(temp_path, 'rb') as f:
            result = collection.insert_one({
                "model_name": args['model_name'],
                "model_data": f.read(),
                "gen_time": args['gen_time'],
                "params": args['params'],
                "descr": args['descr']
            })
        print(f"✅ 模型 {args['model_name']} 保存成功 | 文档ID: {result.inserted_id}")
        return str(result.inserted_id)
    except Exception as e:
        # ------------------------- error classification -------------------------
        error_type = "数据库操作" if isinstance(e, (pymongo.errors.PyMongoError, RuntimeError)) else "系统错误"
        print(f"❌ {error_type} - 详细错误: {str(e)}")
        raise  # re-raise; callers decide how to handle
    finally:
        # ------------------------- cleanup -------------------------
        if client:
            client.close()
        if temp_path and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except PermissionError:
                print(f"⚠️ 临时文件清理失败: {temp_path}")
def insert_scaler_model_into_mongo(feature_scaler_bytes: BytesIO, target_scaler_bytes: BytesIO,
                                   args: Dict[str, Any]) -> str:
    """
    Store feature/target scaler byte streams in MongoDB, keeping the
    collection capped at 50 documents (the oldest by gen_time is evicted).

    Parameters:
        feature_scaler_bytes: BytesIO - serialized feature scaler.
        target_scaler_bytes: BytesIO - serialized target scaler.
        args : dict containing:
            mongodb_database: database name
            scaler_table: collection name
            model_name: associated model name
            gen_time: generation time (datetime)

    Returns:
        str - the inserted document id.

    Raises:
        ValueError: missing args key or empty byte stream.
        TypeError: a scaler argument is not a BytesIO.
    """
    # ------------------------- parameter validation -------------------------
    required_keys = {'mongodb_database', 'scaler_table', 'model_name', 'gen_time'}
    if missing := required_keys - args.keys():
        raise ValueError(f"缺少必要参数: {missing}")
    # ------------------------- configuration -------------------------
    # Prefer the connection string from the environment (keeps credentials
    # out of the on-disk config).
    mongodb_conn = os.getenv("MONGO_URI", config['mongodb']['mongodb_connection'])
    # ------------------------- input validation -------------------------
    for buf, name in [(feature_scaler_bytes, "特征缩放器"),
                      (target_scaler_bytes, "目标缩放器")]:
        if not isinstance(buf, BytesIO):
            raise TypeError(f"{name} 必须为BytesIO类型")
        if buf.getbuffer().nbytes == 0:
            raise ValueError(f"{name} 字节流为空")
    client = None
    try:
        # ------------------------- database connection -------------------------
        client = MongoClient(mongodb_conn)
        db = client[args['mongodb_database']]
        collection = db[args['scaler_table']]
        # ------------------------- capacity control -------------------------
        # Approximate count is cheap; a small error margin is acceptable here.
        if collection.estimated_document_count() >= 50:
            # Atomically delete the oldest document by gen_time.
            if deleted := collection.find_one_and_delete(
                    sort=[("gen_time", ASCENDING)],
                    projection={"_id": 1, "model_name": 1, "gen_time": 1}
            ):
                print(f"🗑️ 已淘汰最旧缩放器 [{deleted['model_name']}] 生成时间: {deleted['gen_time']}")
        # ------------------------- insert -------------------------
        # Rewind both streams so the full payload is read.
        feature_scaler_bytes.seek(0)
        target_scaler_bytes.seek(0)
        result = collection.insert_one({
            "model_name": args['model_name'],
            "gen_time": args['gen_time'],
            "feature_scaler": feature_scaler_bytes.read(),
            "target_scaler": target_scaler_bytes.read()
        })
        print(f"✅ 缩放器 {args['model_name']} 保存成功 | 文档ID: {result.inserted_id}")
        return str(result.inserted_id)
    except Exception as e:
        # ------------------------- error classification -------------------------
        error_type = "数据库操作" if isinstance(e, (pymongo.errors.PyMongoError, ValueError)) else "系统错误"
        print(f"❌ {error_type}异常 - 详细错误: {str(e)}")
        raise  # re-raise; callers decide how to handle
    finally:
        # ------------------------- cleanup -------------------------
        if client:
            client.close()
        # Reset stream positions so callers can reuse the buffers.
        feature_scaler_bytes.seek(0)
        target_scaler_bytes.seek(0)
def get_h5_model_from_mongo(args: Dict[str, Any], custom_objects: Optional[Dict[str, Any]] = None) -> Optional[
        tf.keras.Model]:
    """
    Load the most recent version (by gen_time) of a named Keras model
    from MongoDB.

    Parameters:
        args : dict containing:
            mongodb_database: database name
            model_table: collection name
            model_name: name of the model to fetch
        custom_objects: dict - custom Keras objects for deserialization.

    Returns:
        tf.keras.Model on success, None when no matching document exists.

    Raises:
        ValueError: missing args key.
        RuntimeError: the stored model architecture cannot be loaded.
    """
    # ------------------------- parameter validation -------------------------
    required_keys = {'mongodb_database', 'model_table', 'model_name'}
    if missing := required_keys - args.keys():
        raise ValueError(f"❌ 缺失必要参数: {missing}")
    # ------------------------- environment config -------------------------
    mongo_uri = os.getenv("MONGO_URI", config['mongodb']['mongodb_connection'])
    client = None
    tmp_file_path = None  # track the temp file path for cleanup
    try:
        # ------------------------- database connection -------------------------
        client = MongoClient(
            mongo_uri,
            maxPoolSize=10,  # connection pool tuning
            socketTimeoutMS=5000
        )
        db = client[args['mongodb_database']]
        collection = db[args['model_table']]
        # ------------------------- index maintenance -------------------------
        # Compound index so the latest-by-gen_time query below is indexed.
        index_name = "model_gen_time_idx"
        if index_name not in collection.index_information():
            collection.create_index(
                [("model_name", 1), ("gen_time", DESCENDING)],
                name=index_name
            )
            print("⏱️ 已创建复合索引")
        # ------------------------- query latest -------------------------
        model_doc = collection.find_one(
            {"model_name": args['model_name']},
            sort=[('gen_time', DESCENDING)],
            projection={"model_data": 1, "gen_time": 1}  # only the fields we need
        )
        if not model_doc:
            print(f"⚠️ 未找到模型 '{args['model_name']}' 的有效记录")
            return None
        # ------------------------- load from temp file -------------------------
        if model_doc:
            model_data = model_doc['model_data']  # raw serialized model bytes
            # Keras loads from a path, so spill the bytes to a temp file first
            # (delete=False: removed explicitly in the finally block).
            with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
                tmp_file.write(model_data)
                tmp_file_path = tmp_file.name
            # Load the model from the temp file.
            model = tf.keras.models.load_model(tmp_file_path, custom_objects=custom_objects)
            print(f"{args['model_name']}模型成功从 MongoDB 加载!")
            return model
    except tf.errors.NotFoundError as e:
        print(f"❌ 模型结构缺失关键组件: {str(e)}")
        raise RuntimeError("模型架构不完整") from e
    except Exception as e:
        print(f"❌ 系统异常: {str(e)}")
        raise
    finally:
        # ------------------------- cleanup -------------------------
        if client:
            client.close()
        # Make sure the temp file is deleted.
        if tmp_file_path and os.path.exists(tmp_file_path):
            try:
                os.remove(tmp_file_path)
                print(f"🧹 已清理临时文件: {tmp_file_path}")
            except Exception as cleanup_err:
                print(f"⚠️ 临时文件清理失败: {str(cleanup_err)}")
def get_keras_model_from_mongo(
        args: Dict[str, Any],
        custom_objects: Optional[Dict[str, Any]] = None
        # Annotation fixed: the function returns (model, params), not a bare model.
) -> Optional[Tuple[tf.keras.Model, Any]]:
    """
    Load the most recent version (by gen_time) of a named Keras model from
    MongoDB, together with its stored parameters.

    Parameters:
        args : dict containing:
            mongodb_database: database name
            model_table: collection name
            model_name: name of the model to fetch
        custom_objects: dict - custom Keras objects for deserialization.

    Returns:
        (tf.keras.Model, params) tuple on success, None when no matching
        document exists.

    Raises:
        ValueError: missing args key.
        RuntimeError: the stored model architecture cannot be loaded.
    """
    # ------------------------- parameter validation -------------------------
    required_keys = {'mongodb_database', 'model_table', 'model_name'}
    if missing := required_keys - args.keys():
        raise ValueError(f"❌ 缺失必要参数: {missing}")
    # ------------------------- environment config -------------------------
    mongo_uri = os.getenv("MONGO_URI", config['mongodb']['mongodb_connection'])
    client = None
    tmp_file_path = None  # track the temp file path for cleanup
    try:
        # ------------------------- database connection -------------------------
        client = MongoClient(
            mongo_uri,
            maxPoolSize=10,
            socketTimeoutMS=5000
        )
        db = client[args['mongodb_database']]
        collection = db[args['model_table']]
        # ------------------------- query latest -------------------------
        model_doc = collection.find_one(
            {"model_name": args['model_name']},
            sort=[('gen_time', DESCENDING)],
            projection={"model_data": 1, "gen_time": 1, 'params': 1}
        )
        if not model_doc:
            print(f"⚠️ 未找到模型 '{args['model_name']}' 的有效记录")
            return None
        # ------------------------- load from temp file -------------------------
        model_data = model_doc['model_data']  # raw serialized model bytes
        model_params = model_doc['params']
        # Keras loads from a path, so spill the bytes to a temp file
        # (delete=False: removed explicitly in the finally block).
        with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
            tmp_file.write(model_data)
            tmp_file_path = tmp_file.name
        # Load the model from the temp file.
        model = tf.keras.models.load_model(
            tmp_file_path,
            custom_objects=custom_objects
        )
        print(f"{args['model_name']} 模型成功从 MongoDB 加载!")
        return model, model_params
    except tf.errors.NotFoundError as e:
        print(f"❌ 模型结构缺失关键组件: {str(e)}")
        raise RuntimeError("模型架构不完整") from e
    except Exception as e:
        print(f"❌ 系统异常: {str(e)}")
        raise
    finally:
        # ------------------------- cleanup -------------------------
        if client:
            client.close()
        # Make sure the temp file is deleted.
        if tmp_file_path and os.path.exists(tmp_file_path):
            try:
                os.remove(tmp_file_path)
                print(f"🧹 已清理临时文件: {tmp_file_path}")
            except Exception as cleanup_err:
                print(f"⚠️ 临时文件清理失败: {str(cleanup_err)}")
def get_scaler_model_from_mongo(args: Dict[str, Any], only_feature_scaler: bool = False) -> Union[
        object, Tuple[object, object]]:
    """
    Load the most recent feature/target scalers (by gen_time) for a model
    from MongoDB.

    Parameters:
        args : must contain:
            - mongodb_database: database name
            - scaler_table: collection name
            - model_name: target model name
        only_feature_scaler : when True return only the feature scaler.

    Returns:
        A single scaler object, or a (feature_scaler, target_scaler) tuple.

    Raises:
        ValueError : missing required args key.
        RuntimeError : database failure, missing record, malformed payload,
            or deserialization failure.
    """
    # ------------------------- parameter validation -------------------------
    required_keys = {'mongodb_database', 'scaler_table', 'model_name'}
    if missing := required_keys - args.keys():
        raise ValueError(f"❌ 缺失必要参数: {missing}")
    # ------------------------- environment config -------------------------
    mongo_uri = os.getenv("MONGO_URI", config['mongodb']['mongodb_connection'])
    client = None
    try:
        # ------------------------- database connection -------------------------
        client = MongoClient(
            mongo_uri,
            maxPoolSize=20,  # pool upper bound
            socketTimeoutMS=3000,  # 3s socket timeout
            serverSelectionTimeoutMS=5000  # 5s server selection timeout
        )
        db = client[args['mongodb_database']]
        collection = db[args['scaler_table']]
        # ------------------------- query latest -------------------------
        scaler_doc = collection.find_one(
            {"model_name": args['model_name']},
            sort=[('gen_time', DESCENDING)],
            projection={"feature_scaler": 1, "target_scaler": 1, "gen_time": 1}
        )
        if not scaler_doc:
            raise RuntimeError(f"⚠️ 找不到模型 {args['model_name']} 的缩放器记录")

        # ------------------------- deserialization -------------------------
        def load_scaler(data: bytes) -> object:
            """Deserialize a joblib payload from raw bytes."""
            with BytesIO(data) as buffer:
                buffer.seek(0)  # make sure the cursor is at the start
                try:
                    return joblib.load(buffer)
                except joblib.UnpicklingError as e:
                    raise RuntimeError("反序列化失败 (可能版本不兼容)") from e

        # Feature scaler (always loaded).
        feature_data = scaler_doc["feature_scaler"]
        if not isinstance(feature_data, bytes):
            raise RuntimeError("特征缩放器数据格式异常")
        feature_scaler = load_scaler(feature_data)
        if only_feature_scaler:
            return feature_scaler
        # Target scaler (only when the full pair is requested).
        target_data = scaler_doc["target_scaler"]
        if not isinstance(target_data, bytes):
            raise RuntimeError("目标缩放器数据格式异常")
        target_scaler = load_scaler(target_data)
        print(f"✅ 成功加载 {args['model_name']} 的缩放器 (版本时间: {scaler_doc.get('gen_time', '未知')})")
        return feature_scaler, target_scaler
    except PyMongoError as e:
        raise RuntimeError(f"🔌 数据库操作失败: {str(e)}") from e
    except RuntimeError as e:
        raise RuntimeError(f"🔌 mongo操作失败: {str(e)}") from e  # re-wrap already-classified errors
    except Exception as e:
        raise RuntimeError(f"❌ 未知系统异常: {str(e)}") from e
    finally:
        # ------------------------- cleanup -------------------------
        if client:
            client.close()
  629. def normalize_key(s):
  630. return s.lower()
def get_xmo_data_from_hive(args):
    """
    Fetch XMO NWP forecast rows from Hive (over JDBC) for a wind/solar farm
    and return the configured feature columns as a DataFrame.

    Parameters:
        args: dict with
            farm_id: required farm identifier (query fails without it).
            current_date: 'YYYYMMDDHH' string; also used to derive *moment*
                when 'moment' is absent.  Falls back to now() for the date.
            moment: optional forecast issue moment ('00'/'03'/'09'/'12'/'18').
            days: optional number of extra days to fetch (string or int).

    Returns:
        DataFrame with the configured feature columns on success, or an
        error-message string when farm_id is missing / the result lacks the
        forecast timestamp column.
    """
    # Hive connection settings from the shared TOML config.
    hive_config = config['hive']
    jdbc_url = hive_config['jdbc_url']
    driver_class = hive_config['driver_class']
    user = hive_config['user']
    password = hive_config['password']
    features = config['xmo']['features']
    numeric_features = config['xmo']['numeric_features']
    if 'farm_id' not in args:
        msg_error = 'One or more of the following parameters are missing: farm_id!'
        return msg_error
    else:
        farm_id = args['farm_id']
        if 'moment' in args:
            moment = args['moment']
        else:
            # Derive the forecast issue moment from the hour of current_date.
            hour = datetime.strptime(args['current_date'], "%Y%m%d%H").hour
            if hour < 3:
                moment = '00'
            elif hour < 6:
                moment = '03'
            elif hour < 15:
                moment = '09'
            elif hour < 19:
                moment = '12'
            else:
                moment = '18'
        print(moment)
        if 'current_date' in args:
            current_date = datetime.strptime(args['current_date'], "%Y%m%d%H")
        else:
            current_date = datetime.now()
        if 'days' in args:
            days = int(args['days']) + 1
        else:
            days = 1
        # JSON column that holds this moment's forecast payload.
        json_feature = f"nwp_xmo_{moment}"
        # Open the JDBC connection through the bundled standalone Hive driver.
        """"""
        conn = jaydebeapi.connect(driver_class, jdbc_url, [user, password], jar_file)
        cursor = conn.cursor()
        # Build one UNION of per-day rowkey-range selects; table suffix _d{i}
        # is the day offset from current_date.
        query_sql = ""
        for i in range(0, days):
            sysdate_pre = (current_date + timedelta(days=i)).strftime("%Y%m%d")
            if i == 0:
                pass
            else:
                query_sql += "union \n"
            query_sql += """select rowkey,datatimestamp,{2} from hbase_forecast.forecast_xmo_d{3}
            where rowkey>='{0}-{1}0000' and rowkey<='{0}-{1}2345' \n""".format(
                farm_id, sysdate_pre, json_feature, i)
        print("query_sql\n", query_sql)
        cursor.execute(query_sql)
        # Column names come from the cursor description.
        columns = [desc[0] for desc in cursor.description]
        rows = cursor.fetchall()
        df = pd.DataFrame(rows, columns=columns)
        cursor.close()
        conn.close()
        # The payload column may arrive as a JSON string; parse it to dicts
        # then flatten into one column per field.
        df[json_feature] = df[json_feature].apply(lambda x: json.loads(x) if isinstance(x, str) else x)
        df_features = pd.json_normalize(df[json_feature])
        if 'forecastDatatime' not in df_features.columns:
            return "The returned data does not contain the forecastDatetime column — the data might be empty or null!"
        else:
            # Epoch-ms timestamps -> Asia/Shanghai local-time strings.
            df_features['date_time'] = pd.to_datetime(df_features['forecastDatatime'], unit='ms', utc=True).dt.tz_convert(
                'Asia/Shanghai').dt.strftime('%Y-%m-%d %H:%M:%S')
            df_features[numeric_features] = df_features[numeric_features].apply(pd.to_numeric, errors='coerce')
            return df_features[features]
if __name__ == "__main__":
    print("Program starts execution!")
    # Ad-hoc manual smoke test of the Hive reader.
    args = {
        # 'moment': '06',
        'current_date': '2025060901',
        'farm_id': 'J00883',
        'days': '13'
    }
    get_xmo_data_from_hive(args)
    print("server start!")