@@ -6,13 +6,26 @@ from sqlalchemy import create_engine
 import pickle
 from io import BytesIO
 import joblib
-import h5py
+import json
 import tensorflow as tf
 import os
 import tempfile
+import jaydebeapi
+import toml
+from typing import Dict, Any, Optional, Union, Tuple
+from datetime import datetime, timedelta
+
+# Load the TOML configuration file
+current_dir = os.path.dirname(os.path.abspath(__file__))
+with open(os.path.join(current_dir, 'database.toml'), 'r', encoding='utf-8') as f:
+    config = toml.load(f)  # read-only global configuration
+
+jar_file = os.path.join(current_dir, 'jar/hive-jdbc-standalone.jar')
+

 def get_data_from_mongo(args):
-    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
+    # Read the MongoDB connection string from the config
+    mongodb_connection = config['mongodb']['mongodb_connection']
     mongodb_database = args['mongodb_database']
     mongodb_read_table = args['mongodb_read_table']
     query_dict = {}
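The patch does not include database.toml itself; the new code reads config['mongodb']['mongodb_connection'], config['hive']['jdbc_url' / 'driver_class' / 'user' / 'password'], and config['xmo']['features' / 'numeric_features'], and several functions let a MONGO_URI environment variable override the MongoDB entry. A minimal sketch of a file with that shape, parsed the same way the module does (every value and feature name below is a placeholder, not taken from this repository):

import toml

sample_config = toml.loads("""
[mongodb]
mongodb_connection = "mongodb://user:password@127.0.0.1:27017/"

[hive]
jdbc_url = "jdbc:hive2://hive-host:10000/default"
driver_class = "org.apache.hive.jdbc.HiveDriver"
user = "hive_user"
password = "hive_password"

[xmo]
features = ["date_time", "speed10", "dir10"]          # placeholder feature names
numeric_features = ["speed10", "dir10"]
""")

assert sample_config['mongodb']['mongodb_connection'].startswith("mongodb://")
assert set(sample_config['xmo']['numeric_features']) <= set(sample_config['xmo']['features'])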
@@ -42,7 +55,9 @@ def get_data_from_mongo(args):
 
 
 def get_df_list_from_mongo(args):
-    mongodb_connection,mongodb_database,mongodb_read_table = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['mongodb_read_table'].split(',')
+    # Read the MongoDB connection string from the config
+    mongodb_connection = config['mongodb']['mongodb_connection']
+    mongodb_database, mongodb_read_table = args['mongodb_database'], args['mongodb_read_table'].split(',')
     df_list = []
     client = MongoClient(mongodb_connection)
     # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
@@ -58,6 +73,7 @@ def get_df_list_from_mongo(args):
     client.close()
     return df_list
 
+
 def insert_data_into_mongo(res_df, args):
     """
     插入数据到 MongoDB 集合中,可以选择覆盖、追加或按指定的 key 进行更新插入。
@@ -68,7 +84,8 @@ def insert_data_into_mongo(res_df, args):
     - overwrite: 布尔值,True 表示覆盖,False 表示追加
     - update_keys: 列表,指定用于匹配的 key 列,如果存在则更新,否则插入 'col1','col2'
     """
-    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
+    # Read the MongoDB connection string from the config
+    mongodb_connection = config['mongodb']['mongodb_connection']
     mongodb_database = args['mongodb_database']
     mongodb_write_table = args['mongodb_write_table']
     overwrite = 1
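With this change the Mongo readers take only database and collection names in args and pull the connection string from the shared config; a minimal call sketch (database and collection names are placeholders, and the exact return values follow whatever the unchanged function bodies produce):

single_df = get_data_from_mongo({
    'mongodb_database': 'weather_db',
    'mongodb_read_table': 'nwp_2025_06',
})

df_list = get_df_list_from_mongo({
    'mongodb_database': 'weather_db',
    'mongodb_read_table': 'nwp_2025_06,nwp_2025_07',  # comma-separated list of collections
})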
@@ -119,7 +136,7 @@ def insert_data_into_mongo(res_df, args):
 def get_data_fromMysql(params):
     mysql_conn = params['mysql_conn']
     query_sql = params['query_sql']
-    #数据库读取实测气象
+    # Read the observed weather data from the database
     engine = create_engine(f"mysql+pymysql://{mysql_conn}")
     # 定义SQL查询
     with engine.connect() as conn:
@@ -128,8 +145,10 @@ def get_data_fromMysql(params):
 
 
 def insert_pickle_model_into_mongo(model, args):
-    mongodb_connection, mongodb_database, mongodb_write_table, model_name = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/", \
-        args['mongodb_database'], args['mongodb_write_table'], args['model_name']
+    # Read the MongoDB connection string from the config
+    mongodb_connection = config['mongodb']['mongodb_connection']
+    mongodb_database, mongodb_write_table, model_name = args['mongodb_database'], args['mongodb_write_table'], args[
+        'model_name']
     client = MongoClient(mongodb_connection)
     db = client[mongodb_database]
     # 序列化模型
@@ -149,9 +168,27 @@ def insert_pickle_model_into_mongo(model, args):
     print("model inserted successfully!")
 
 
-def insert_h5_model_into_mongo(model,feature_scaler_bytes,target_scaler_bytes ,args):
-    mongodb_connection,mongodb_database,scaler_table,model_table,model_name = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
-                                                args['mongodb_database'],args['scaler_table'],args['model_table'],args['model_name'])
+def get_pickle_model_from_mongo(args):
+    mongodb_connection = config['mongodb']['mongodb_connection']
+    mongodb_database, mongodb_model_table, model_name = args['mongodb_database'], args['mongodb_model_table'], args['model_name']
+    client = MongoClient(mongodb_connection)
+    db = client[mongodb_database]
+    collection = db[mongodb_model_table]
+    model_data = collection.find_one({"model_name": model_name})
+    if model_data is not None:
+        model_binary = model_data['model']  # this field must hold the pickled model binary
+        # Deserialize the model
+        model = pickle.loads(model_binary)
+        return model
+    else:
+        return None
+
+
+def insert_h5_model_into_mongo(model, feature_scaler_bytes, target_scaler_bytes, args):
+    # Read the MongoDB connection string from the config
+    mongodb_connection = config['mongodb']['mongodb_connection']
+    mongodb_database, scaler_table, model_table, model_name = args['mongodb_database'], args['scaler_table'], args[
+        'model_table'], args['model_name']
     client = MongoClient(mongodb_connection)
     db = client[mongodb_database]
     if scaler_table in db.list_collection_names():
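A round-trip sketch for the pickle-based helpers; note that the writer takes the collection name as mongodb_write_table while the new reader expects mongodb_model_table (all names below are placeholders, and any picklable estimator works):

from sklearn.linear_model import LinearRegression

model = LinearRegression().fit([[0.0], [1.0]], [0.0, 1.0])

insert_pickle_model_into_mongo(model, {
    'mongodb_database': 'models_db',
    'mongodb_write_table': 'pickle_models',
    'model_name': 'demo_lr',
})

restored = get_pickle_model_from_mongo({
    'mongodb_database': 'models_db',
    'mongodb_model_table': 'pickle_models',
    'model_name': 'demo_lr',
})
if restored is not None:
    print(restored.predict([[2.0]]))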
@@ -209,65 +246,341 @@ def insert_h5_model_into_mongo(model,feature_scaler_bytes,target_scaler_bytes ,a
                 print(f"⚠️ 临时文件清理失败: {temp_path}")
 
 
-# def insert_trained_model_into_mongo(model, args):
-#     mongodb_connection,mongodb_database,model_table,model_name = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
-#                                     args['mongodb_database'],args['model_table'],args['model_name'])
-#
-#     gen_time, params_json, descr = args['gen_time'], args['params'], args['descr']
-#     client = MongoClient(mongodb_connection)
-#     db = client[mongodb_database]
-#     if model_table in db.list_collection_names():
-#         db[model_table].drop()
-#         print(f"Collection '{model_table} already exist, deleted successfully!")
-#     model_table = db[model_table]
-#
-#     # 创建 BytesIO 缓冲区
-#     model_buffer = BytesIO()
-#     # 将模型保存为 HDF5 格式到内存 (BytesIO)
-#     model.save(model_buffer, save_format='h5')
-#     # 将指针移到缓冲区的起始位置
-#     model_buffer.seek(0)
-#     # 获取模型的二进制数据
-#     model_data = model_buffer.read()
-#     # 将模型保存到 MongoDB
-#     model_table.insert_one({
-#         "model_name": model_name,
-#         "model_data": model_data,
-#         "gen_time": gen_time,
-#         "params": params_json,
-#         "descr": descr
-#     })
-#     print("模型成功保存到 MongoDB!")
-
-def insert_scaler_model_into_mongo(feature_scaler_bytes, scaled_target_bytes, args):
-    mongodb_connection,mongodb_database,scaler_table,model_table,model_name = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
-                                                args['mongodb_database'],args['scaler_table'],args['model_table'],args['model_name'])
-    client = MongoClient(mongodb_connection)
-    db = client[mongodb_database]
-    if scaler_table in db.list_collection_names():
-        db[scaler_table].drop()
-        print(f"Collection '{scaler_table} already exist, deleted successfully!")
-    collection = db[scaler_table]  # 集合名称
-    # Save the scalers in MongoDB as binary data
-    collection.insert_one({
-        "feature_scaler": feature_scaler_bytes.read(),
-        "target_scaler": scaled_target_bytes.read()
-    })
-    client.close()
-    print("scaler_model inserted successfully!")
+def insert_trained_model_into_mongo(model: tf.keras.Model, args: Dict[str, Any]) -> str:
+    """
+    Insert a trained Keras model into MongoDB, keeping the collection capped at 50 models.
+    Arguments:
+    model : tf.keras.Model - the trained Keras model
+    args : dict with the following keys:
+        mongodb_database: database name
+        model_table: collection name
+        model_name: model name
+        gen_time: model generation time (datetime object)
+        params: model parameters (JSON-serializable object)
+        descr: model description text
+    """
+    # ------------------------- Parameter validation -------------------------
+    required_keys = {'mongodb_database', 'model_table', 'model_name',
+                     'gen_time', 'params', 'descr'}
+    if missing := required_keys - args.keys():
+        raise ValueError(f"Missing required parameters: {missing}")
+
+    # ------------------------- Configuration -------------------------
+    # Prefer the connection string from an environment variable (safer)
+    mongodb_connection = os.getenv("MONGO_URI", config['mongodb']['mongodb_connection'])
 
-def get_h5_model_from_mongo(args, custom=None):
-    mongodb_connection,mongodb_database,model_table,model_name = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['model_table'],args['model_name']
-    client = MongoClient(mongodb_connection)
-    # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
-    db = client[mongodb_database]
-    collection = db[model_table]  # 集合名称
+    # ------------------------- Resource initialization -------------------------
+    fd, temp_path = None, None
+    client = None
+
+    try:
+        # ------------------------- Temporary file handling -------------------------
+        fd, temp_path = tempfile.mkstemp(suffix='.keras')
+        os.close(fd)  # release the file handle immediately
+
+        # ------------------------- Model saving -------------------------
+        try:
+            model.save(temp_path)  # no save_format given; defaults to the new Keras format
+        except Exception as e:
+            raise RuntimeError(f"Failed to save the model: {str(e)}") from e
+
+        # ------------------------- Database connection -------------------------
+        client = MongoClient(mongodb_connection)
+        db = client[args['mongodb_database']]
+        collection = db[args['model_table']]
+
+        # ------------------------- Index check -------------------------
+        # index_info = collection.index_information()
+        # if "gen_time_1" not in index_info:
+        #     print("Creating index...")
+        #     collection.create_index(
+        #         [("gen_time", ASCENDING)],
+        #         name="gen_time_1",
+        #         background=True
+        #     )
+        #     print("Index created successfully")
+        # else:
+        #     print("Index already exists, skipping creation")
+
+        # ------------------------- Capacity control -------------------------
+        # Use the cheaper approximate count
+        if collection.estimated_document_count() >= 50:
+            # Atomic delete of the oldest document
+            if deleted := collection.find_one_and_delete(
+                    sort=[("gen_time", ASCENDING)],
+                    projection={"_id": 1, "model_name": 1, "gen_time": 1}
+            ):
+                print(f"Evicted model [{deleted['model_name']}], generated at: {deleted['gen_time']}")
 
-    # 查询 MongoDB 获取模型数据
-    model_doc = collection.find_one({"model_name": model_name})
-    if model_doc:
-        model_data = model_doc['model_data']  # 获取模型的二进制数据
+        # ------------------------- Document insertion -------------------------
+        with open(temp_path, 'rb') as f:
+            result = collection.insert_one({
+                "model_name": args['model_name'],
+                "model_data": f.read(),
+                "gen_time": args['gen_time'],
+                "params": args['params'],
+                "descr": args['descr']
+            })
+
+        print(f"✅ Model {args['model_name']} saved successfully | document ID: {result.inserted_id}")
+        return str(result.inserted_id)
+
+    except Exception as e:
+        # ------------------------- Exception classification -------------------------
+        error_type = "database error" if isinstance(e, (pymongo.errors.PyMongoError, RuntimeError)) else "system error"
+        print(f"❌ {error_type} - details: {str(e)}")
+        raise  # re-raise or swallow here depending on business needs
+
+    finally:
+        # ------------------------- Resource cleanup -------------------------
+        if client:
+            client.close()
+        if temp_path and os.path.exists(temp_path):
+            try:
+                os.remove(temp_path)
+            except PermissionError:
+                print(f"⚠️ Failed to remove the temporary file: {temp_path}")
+
+
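A call sketch for the new insert_trained_model_into_mongo; trained_model stands for a tf.keras.Model produced elsewhere in the pipeline, the names and params are placeholders, and ASCENDING / pymongo are assumed to be imported in the unchanged header of this module:

from datetime import datetime

doc_id = insert_trained_model_into_mongo(trained_model, {
    'mongodb_database': 'models_db',
    'model_table': 'keras_models',
    'model_name': 'wind_power_lstm',
    'gen_time': datetime.now(),
    'params': {'hidden_units': 64, 'lookback': 24},  # any JSON-serializable parameters
    'descr': 'model retrained on the latest NWP data',
})
print(f"stored document id: {doc_id}")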
+def insert_scaler_model_into_mongo(feature_scaler_bytes: BytesIO, target_scaler_bytes: BytesIO,
+                                   args: Dict[str, Any]) -> str:
+    """
+    Store the feature/target scalers in MongoDB, keeping the collection capped at 50 documents.
+
+    Arguments:
+    feature_scaler_bytes: BytesIO - serialized feature scaler
+    target_scaler_bytes: BytesIO - serialized target scaler
+    args : dict with the following keys:
+        mongodb_database: database name
+        scaler_table: collection name
+        model_name: associated model name
+        gen_time: generation time (datetime object)
+    """
+    # ------------------------- Parameter validation -------------------------
+    required_keys = {'mongodb_database', 'scaler_table', 'model_name', 'gen_time'}
+    if missing := required_keys - args.keys():
+        raise ValueError(f"Missing required parameters: {missing}")
+
+    # ------------------------- Configuration -------------------------
+    # Prefer the connection string from an environment variable (keeps credentials out of the code)
+    mongodb_conn = os.getenv("MONGO_URI", config['mongodb']['mongodb_connection'])
+
+    # ------------------------- Input validation -------------------------
+    for buf, name in [(feature_scaler_bytes, "feature scaler"),
+                      (target_scaler_bytes, "target scaler")]:
+        if not isinstance(buf, BytesIO):
+            raise TypeError(f"{name} must be a BytesIO instance")
+        if buf.getbuffer().nbytes == 0:
+            raise ValueError(f"{name} byte stream is empty")
+
+    client = None
+    try:
+        # ------------------------- Database connection -------------------------
+        client = MongoClient(mongodb_conn)
+        db = client[args['mongodb_database']]
+        collection = db[args['scaler_table']]
+
+        # ------------------------- Index maintenance -------------------------
+        # if "gen_time_1" not in collection.index_information():
+        #     collection.create_index([("gen_time", ASCENDING)], name="gen_time_1")
+        #     print("⏱️ Created time-sorted index")
+
+        # ------------------------- Capacity control -------------------------
+        # Approximate count for performance (an error of a few dozen documents is acceptable)
+        if collection.estimated_document_count() >= 50:
+            # Atomic delete of the oldest document
+            if deleted := collection.find_one_and_delete(
+                    sort=[("gen_time", ASCENDING)],
+                    projection={"_id": 1, "model_name": 1, "gen_time": 1}
+            ):
+                print(f"🗑️ Evicted the oldest scaler [{deleted['model_name']}], generated at: {deleted['gen_time']}")
+
+        # ------------------------- Document insertion -------------------------
+        # Make sure the byte-stream pointers are at the start
+        feature_scaler_bytes.seek(0)
+        target_scaler_bytes.seek(0)
+
+        result = collection.insert_one({
+            "model_name": args['model_name'],
+            "gen_time": args['gen_time'],
+            "feature_scaler": feature_scaler_bytes.read(),
+            "target_scaler": target_scaler_bytes.read()
+        })
+
+        print(f"✅ Scalers for {args['model_name']} saved successfully | document ID: {result.inserted_id}")
+        return str(result.inserted_id)
+
+    except Exception as e:
+        # ------------------------- Exception classification -------------------------
+        error_type = "database error" if isinstance(e, (pymongo.errors.PyMongoError, ValueError)) else "system error"
+        print(f"❌ {error_type} - details: {str(e)}")
+        raise  # re-raise or swallow here depending on business needs
+
+    finally:
+        # ------------------------- Resource cleanup -------------------------
+        if client:
+            client.close()
+        # Reset the byte-stream pointers so the buffers can be reused
+        feature_scaler_bytes.seek(0)
+        target_scaler_bytes.seek(0)
+
+
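The scaler writer expects non-empty BytesIO buffers; a sketch of producing them with joblib and an assumed scikit-learn MinMaxScaler (names and fitted ranges are placeholders):

from datetime import datetime
from io import BytesIO

import joblib
from sklearn.preprocessing import MinMaxScaler

feature_buf, target_buf = BytesIO(), BytesIO()
joblib.dump(MinMaxScaler().fit([[0.0], [10.0]]), feature_buf)    # fitted feature scaler
joblib.dump(MinMaxScaler().fit([[0.0], [100.0]]), target_buf)    # fitted target scaler

insert_scaler_model_into_mongo(feature_buf, target_buf, {
    'mongodb_database': 'models_db',
    'scaler_table': 'scalers',
    'model_name': 'wind_power_lstm',
    'gen_time': datetime.now(),
})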
+def get_h5_model_from_mongo(args: Dict[str, Any], custom_objects: Optional[Dict[str, Any]] = None) -> Optional[
+    tf.keras.Model]:
+    """
+    Fetch the latest version of the named model from MongoDB.
+
+    Arguments:
+    args : dict with the following keys:
+        mongodb_database: database name
+        model_table: collection name
+        model_name: name of the model to fetch
+    custom_objects: dict - custom Keras objects needed to deserialize the model
+
+    Returns:
+    tf.keras.Model - the loaded Keras model
+    """
+    # ------------------------- Parameter validation -------------------------
+    required_keys = {'mongodb_database', 'model_table', 'model_name'}
+    if missing := required_keys - args.keys():
+        raise ValueError(f"❌ Missing required parameters: {missing}")
+
+    # ------------------------- Environment configuration -------------------------
+    mongo_uri = os.getenv("MONGO_URI", config['mongodb']['mongodb_connection'])
+    client = None
+    tmp_file_path = None  # track the temporary file path
+    try:
+        # ------------------------- Database connection -------------------------
+        client = MongoClient(
+            mongo_uri,
+            maxPoolSize=10,  # connection pool size
+            socketTimeoutMS=5000
+        )
+        db = client[args['mongodb_database']]
+        collection = db[args['model_table']]
+
+        # ------------------------- Index maintenance -------------------------
+        index_name = "model_gen_time_idx"
+        if index_name not in collection.index_information():
+            collection.create_index(
+                [("model_name", 1), ("gen_time", DESCENDING)],
+                name=index_name
+            )
+            print("⏱️ Created compound index")
+
+        # ------------------------- Query -------------------------
+        model_doc = collection.find_one(
+            {"model_name": args['model_name']},
+            sort=[('gen_time', DESCENDING)],
+            projection={"model_data": 1, "gen_time": 1}  # fetch only the required fields
+        )
+
+        if not model_doc:
+            print(f"⚠️ No valid record found for model '{args['model_name']}'")
+            return None
+
+        # ------------------------- Model loading -------------------------
+        if model_doc:
+            model_data = model_doc['model_data']  # the model's binary payload
+            # # Load the binary data into a BytesIO buffer
+            # model_buffer = BytesIO(model_data)
+            # # Make sure the pointer is at the start
+            # model_buffer.seek(0)
+            # # Load the model from the in-memory buffer
+            # # using h5py with a BytesIO file object
+            # with h5py.File(model_buffer, 'r', driver='fileobj') as f:
+            #     model = tf.keras.models.load_model(f, custom_objects=custom_objects)
+            # Create a temporary file
+            with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
+                tmp_file.write(model_data)
+                tmp_file_path = tmp_file.name  # remember the temporary file path
+
+            # Load the model from the temporary file
+            model = tf.keras.models.load_model(tmp_file_path, custom_objects=custom_objects)
+
+            print(f"Model {args['model_name']} loaded successfully from MongoDB!")
+            return model
+    except tf.errors.NotFoundError as e:
+        print(f"❌ The model structure is missing required components: {str(e)}")
+        raise RuntimeError("Incomplete model architecture") from e
+
+    except Exception as e:
+        print(f"❌ Unexpected error: {str(e)}")
+        raise
+
+    finally:
+        # ------------------------- Resource cleanup -------------------------
+        if client:
+            client.close()
+        # Make sure the temporary file is removed
+        if tmp_file_path and os.path.exists(tmp_file_path):
+            try:
+                os.remove(tmp_file_path)
+                print(f"🧹 Removed temporary file: {tmp_file_path}")
+            except Exception as cleanup_err:
+                print(f"⚠️ Failed to remove the temporary file: {str(cleanup_err)}")
+
+
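When the stored network uses custom layers or losses, the same objects must be supplied through custom_objects at load time; a sketch where rmse_loss and the database/collection names are hypothetical:

import tensorflow as tf


def rmse_loss(y_true, y_pred):
    # hypothetical custom loss registered when the model was trained
    return tf.sqrt(tf.reduce_mean(tf.square(y_true - y_pred)))


model = get_h5_model_from_mongo(
    {'mongodb_database': 'models_db', 'model_table': 'keras_models', 'model_name': 'wind_power_lstm'},
    custom_objects={'rmse_loss': rmse_loss},
)
if model is not None:
    model.summary()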
+def get_keras_model_from_mongo(
+        args: Dict[str, Any],
+        custom_objects: Optional[Dict[str, Any]] = None
+) -> Optional[Tuple[tf.keras.Model, Any]]:
+    """
+    Fetch the latest version of the named model from MongoDB (Keras format).
+
+    Arguments:
+    args : dict with the following keys:
+        mongodb_database: database name
+        model_table: collection name
+        model_name: name of the model to fetch
+    custom_objects: dict - custom Keras objects needed to deserialize the model
+
+    Returns:
+    (tf.keras.Model, params) - the loaded Keras model and the params stored with it
+    """
+    # ------------------------- Parameter validation -------------------------
+    required_keys = {'mongodb_database', 'model_table', 'model_name'}
+    if missing := required_keys - args.keys():
+        raise ValueError(f"❌ Missing required parameters: {missing}")
+
+    # ------------------------- Environment configuration -------------------------
+    mongo_uri = os.getenv("MONGO_URI", config['mongodb']['mongodb_connection'])
+    client = None
+    tmp_file_path = None  # track the temporary file path
+
+    try:
+        # ------------------------- Database connection -------------------------
+        client = MongoClient(
+            mongo_uri,
+            maxPoolSize=10,
+            socketTimeoutMS=5000
+        )
+        db = client[args['mongodb_database']]
+        collection = db[args['model_table']]
+
+        # ------------------------- Index maintenance -------------------------
+        # index_name = "model_gen_time_idx"
+        # if index_name not in collection.index_information():
+        #     collection.create_index(
+        #         [("model_name", 1), ("gen_time", DESCENDING)],
+        #         name=index_name
+        #     )
+        #     print("⏱️ Created compound index")
+
+        # ------------------------- Query -------------------------
+        model_doc = collection.find_one(
+            {"model_name": args['model_name']},
+            sort=[('gen_time', DESCENDING)],
+            projection={"model_data": 1, "gen_time": 1, 'params': 1}
+        )
+
+        if not model_doc:
+            print(f"⚠️ No valid record found for model '{args['model_name']}'")
+            return None
+
+        # ------------------------- Model loading -------------------------
+        model_data = model_doc['model_data']
+        model_params = model_doc['params']
         # 创建临时文件(自动删除)
         with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
             tmp_file.write(model_data)
@@ -276,11 +589,25 @@ def get_h5_model_from_mongo(args, custom=None):
             # 从临时文件加载模型
             model = tf.keras.models.load_model(
                 tmp_file_path,
-                custom_objects=custom
+                custom_objects=custom_objects
             )
 
             print(f"{args['model_name']} 模型成功从 MongoDB 加载!")
-        client.close()
+        return model, model_params
+
+    except tf.errors.NotFoundError as e:
+        print(f"❌ The model structure is missing required components: {str(e)}")
+        raise RuntimeError("Incomplete model architecture") from e
+
+    except Exception as e:
+        print(f"❌ Unexpected error: {str(e)}")
+        raise
+
+    finally:
+        # ------------------------- Resource cleanup -------------------------
+        if client:
+            client.close()
+
         # 确保删除临时文件
         if tmp_file_path and os.path.exists(tmp_file_path):
             try:
@@ -288,28 +615,180 @@ def get_h5_model_from_mongo(args, custom=None):
                 print(f"🧹 已清理临时文件: {tmp_file_path}")
             except Exception as cleanup_err:
                 print(f"⚠️ 临时文件清理失败: {str(cleanup_err)}")
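Unlike get_h5_model_from_mongo, the Keras variant returns the model together with the params document stored alongside it, or None when nothing matches; a call sketch (names are placeholders):

result = get_keras_model_from_mongo({
    'mongodb_database': 'models_db',
    'model_table': 'keras_models',
    'model_name': 'wind_power_lstm',
})
if result is not None:
    model, train_params = result
    print(train_params)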
-        return model
-    else:
-        print(f"未找到model_name为 {model_name} 的模型。")
-        client.close()
-        return None
 
 
-def get_scaler_model_from_mongo(args, only_feature_scaler=False):
-    mongodb_connection, mongodb_database, scaler_table, = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/", args['mongodb_database'], args['scaler_table'])
-    client = MongoClient(mongodb_connection)
-    # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
-    db = client[mongodb_database]
-    collection = db[scaler_table]  # 集合名称
-    # Retrieve the scalers from MongoDB
-    scaler_doc = collection.find_one()
-    # Deserialize the scalers
-
-    feature_scaler_bytes = BytesIO(scaler_doc["feature_scaler"])
-    feature_scaler = joblib.load(feature_scaler_bytes)
-    if only_feature_scaler:
-        return feature_scaler
-    target_scaler_bytes = BytesIO(scaler_doc["target_scaler"])
-    target_scaler = joblib.load(target_scaler_bytes)
-    client.close()
-    return feature_scaler,target_scaler
+def get_scaler_model_from_mongo(args: Dict[str, Any], only_feature_scaler: bool = False) -> Union[
+    object, Tuple[object, object]]:
+    """
+    Load the latest preprocessing scalers for a model, safely and efficiently.
+
+    Arguments:
+    args : must contain the keys:
+        - mongodb_database: database name
+        - scaler_table: collection name
+        - model_name: target model name
+    only_feature_scaler : whether to return only the feature scaler
+
+    Returns:
+    a single scaler object, or a (feature_scaler, target_scaler) tuple
+
+    Raises:
+    ValueError : missing parameters or wrong types
+    RuntimeError : data access failure
+    """
+    # ------------------------- Parameter validation -------------------------
+    required_keys = {'mongodb_database', 'scaler_table', 'model_name'}
+    if missing := required_keys - args.keys():
+        raise ValueError(f"❌ Missing required parameters: {missing}")
+
+    # ------------------------- Environment configuration -------------------------
+    mongo_uri = os.getenv("MONGO_URI", config['mongodb']['mongodb_connection'])
+
+    client = None
+    try:
+        # ------------------------- Database connection -------------------------
+        client = MongoClient(
+            mongo_uri,
+            maxPoolSize=20,  # connection pool limit
+            socketTimeoutMS=3000,  # 3-second socket timeout
+            serverSelectionTimeoutMS=5000  # 5-second server selection timeout
+        )
+        db = client[args['mongodb_database']]
+        collection = db[args['scaler_table']]
+
+        # ------------------------- Index maintenance -------------------------
+        # index_name = "model_gen_time_idx"
+        # if index_name not in collection.index_information():
+        #     collection.create_index(
+        #         [("model_name", 1), ("gen_time", DESCENDING)],
+        #         name=index_name,
+        #         background=True  # build in the background to avoid blocking
+        #     )
+        #     print("⏱️ Created compound index for the scaler collection")
+
+        # ------------------------- Query -------------------------
+        scaler_doc = collection.find_one(
+            {"model_name": args['model_name']},
+            sort=[('gen_time', DESCENDING)],
+            projection={"feature_scaler": 1, "target_scaler": 1, "gen_time": 1}
+        )
+
+        if not scaler_doc:
+            raise RuntimeError(f"⚠️ No scaler record found for model {args['model_name']}")
+
+        # ------------------------- Deserialization -------------------------
+        def load_scaler(data: bytes) -> object:
+            """Safely load a serialized object."""
+            with BytesIO(data) as buffer:
+                buffer.seek(0)  # make sure the pointer is at the start
+                try:
+                    return joblib.load(buffer)
+                except pickle.UnpicklingError as e:  # joblib does not define UnpicklingError; catch the stdlib pickle one
+                    raise RuntimeError("Deserialization failed (possible version mismatch)") from e
+
+        # Load the feature scaler
+        feature_data = scaler_doc["feature_scaler"]
+        if not isinstance(feature_data, bytes):
+            raise RuntimeError("Unexpected feature scaler data format")
+        feature_scaler = load_scaler(feature_data)
+
+        if only_feature_scaler:
+            return feature_scaler
+
+        # Load the target scaler
+        target_data = scaler_doc["target_scaler"]
+        if not isinstance(target_data, bytes):
+            raise RuntimeError("Unexpected target scaler data format")
+        target_scaler = load_scaler(target_data)
+
+        print(f"✅ Loaded scalers for {args['model_name']} (version time: {scaler_doc.get('gen_time', 'unknown')})")
+        return feature_scaler, target_scaler
+
+    except PyMongoError as e:
+        raise RuntimeError(f"🔌 Database operation failed: {str(e)}") from e
+    except RuntimeError as e:
+        raise RuntimeError(f"🔌 MongoDB access failed: {str(e)}") from e  # re-wrap with context and propagate
+    except Exception as e:
+        raise RuntimeError(f"❌ Unknown system error: {str(e)}") from e
+    finally:
+        # ------------------------- Resource cleanup -------------------------
+        if client:
+            client.close()
+
+
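A retrieval sketch for get_scaler_model_from_mongo; the transform call assumes a single-feature scaler and is purely illustrative, and the names are placeholders:

scaler_args = {
    'mongodb_database': 'models_db',
    'scaler_table': 'scalers',
    'model_name': 'wind_power_lstm',
}
feature_scaler, target_scaler = get_scaler_model_from_mongo(scaler_args)
print(feature_scaler.transform([[5.0]]))

# Only the feature scaler is needed at inference time:
feature_only = get_scaler_model_from_mongo(scaler_args, only_feature_scaler=True)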
+def normalize_key(s):
+    return s.lower()
+
+
+def get_xmo_data_from_hive(args):
+    # Read the Hive section of the config
+    hive_config = config['hive']
+    jdbc_url = hive_config['jdbc_url']
+    driver_class = hive_config['driver_class']
+    user = hive_config['user']
+    password = hive_config['password']
+    features = config['xmo']['features']
+    numeric_features = config['xmo']['numeric_features']
+    if 'moment' not in args or 'farm_id' not in args:
+        msg_error = 'One or more of the following parameters are missing: moment, farm_id!'
+        return msg_error
+    else:
+        moment = args['moment']
+        farm_id = args['farm_id']
+
+    if 'current_date' in args:
+        current_date = datetime.strptime(args['current_date'], "%Y%m%d")
+    else:
+        current_date = datetime.now()
+    if 'days' in args:
+        days = int(args['days']) + 1
+    else:
+        days = 1
+    json_feature = f"nwp_xmo_{moment}"
+    # Open the JDBC connection
+    conn = jaydebeapi.connect(driver_class, jdbc_url, [user, password], jar_file)
+    # Query the Hive table
+    cursor = conn.cursor()
+    query_sql = ""
+    for i in range(0, days):
+        sysdate_pre = (current_date + timedelta(days=i)).strftime("%Y%m%d")
+        if i == 0:
+            pass
+        else:
+            query_sql += "union \n"
+
+        query_sql += """select rowkey,datatimestamp,{2} from hbase_forecast.forecast_xmo_d{3}
+         where rowkey>='{0}-{1}0000' and rowkey<='{0}-{1}2345' \n""".format(
+            farm_id, sysdate_pre, json_feature, i)
+    print("query_sql\n", query_sql)
+    cursor.execute(query_sql)
+    # Column names
+    columns = [desc[0] for desc in cursor.description]
+    # Fetch all rows
+    rows = cursor.fetchall()
+    # Convert to a DataFrame
+    df = pd.DataFrame(rows, columns=columns)
+    cursor.close()
+    conn.close()
+    df[json_feature] = df[json_feature].apply(lambda x: json.loads(x) if isinstance(x, str) else x)
+    df_features = pd.json_normalize(df[json_feature])
+    if 'forecastDatatime' not in df_features.columns:
+        return "The returned data does not contain the forecastDatatime column; the data may be empty or null!"
+    else:
+        df_features['date_time'] = pd.to_datetime(df_features['forecastDatatime'], unit='ms', utc=True).dt.tz_convert(
+            'Asia/Shanghai').dt.strftime('%Y-%m-%d %H:%M:%S')
+        df_features[numeric_features] = df_features[numeric_features].apply(pd.to_numeric, errors='coerce')
+        return df_features[features]
+
+
+if __name__ == "__main__":
+    print("Program starts execution!")
+    args = {
+        'moment': '06',
+        'current_date': '20250609',
+        'farm_id': 'J00883',
+        'days': '13'
+    }
+    df = get_xmo_data_from_hive(args)
+    print(df.head(2), df.shape)
+    print("server start!")