import os
import pickle
import tempfile
from io import BytesIO
from typing import Any, Dict, Optional, Tuple, Union

import h5py
import joblib
import pandas as pd
import pymongo
import tensorflow as tf
from pymongo import MongoClient, UpdateOne, ASCENDING, DESCENDING
from pymongo.errors import PyMongoError
from sqlalchemy import create_engine
def get_data_from_mongo(args):
    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
    mongodb_database = args['mongodb_database']
    mongodb_read_table = args['mongodb_read_table']
    # Build an optional dateTime range filter from timeBegin/timeEnd
    query_dict = {}
    if 'timeBegin' in args:
        query_dict["$gte"] = args['timeBegin']
    if 'timeEnd' in args:
        query_dict["$lte"] = args['timeEnd']
    client = MongoClient(mongodb_connection)
    # Select the database (MongoDB creates it automatically if it does not exist)
    db = client[mongodb_database]
    collection = db[mongodb_read_table]  # collection name
    if query_dict:
        cursor = collection.find({"dateTime": query_dict})
    else:
        cursor = collection.find()
    data = list(cursor)
    df = pd.DataFrame(data)
    # Drop the _id field (optional)
    if '_id' in df.columns:
        df = df.drop(columns=['_id'])
    client.close()
    return df
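# Usage sketch for get_data_from_mongo: the database/collection names and the
# time-window strings below are hypothetical placeholders, not values taken
# from this project. Assumes a reachable MongoDB; not executed at import time.
def _example_get_data_from_mongo():
    args = {
        'mongodb_database': 'power_forecast',   # hypothetical database
        'mongodb_read_table': 'nwp_weather',    # hypothetical collection
        'timeBegin': '2024-01-01 00:00:00',
        'timeEnd': '2024-01-02 00:00:00',
    }
    df = get_data_from_mongo(args)
    print(df.shape)
    return df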
def get_df_list_from_mongo(args):
    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
    mongodb_database = args['mongodb_database']
    mongodb_read_table = args['mongodb_read_table'].split(',')
    df_list = []
    client = MongoClient(mongodb_connection)
    # Select the database (MongoDB creates it automatically if it does not exist)
    db = client[mongodb_database]
    for table in mongodb_read_table:
        collection = db[table]  # collection name
        data_from_db = collection.find()  # returns a cursor
        # Convert the cursor to a list and build a pandas DataFrame
        df = pd.DataFrame(list(data_from_db))
        if '_id' in df.columns:
            df = df.drop(columns=['_id'])
        df_list.append(df)
    client.close()
    return df_list
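# Usage sketch for get_df_list_from_mongo: reads several collections in one
# call. The comma-separated collection names are hypothetical. Not executed
# at import time.
def _example_get_df_list_from_mongo():
    args = {
        'mongodb_database': 'power_forecast',                 # hypothetical
        'mongodb_read_table': 'nwp_weather,station_power',    # hypothetical, comma-separated
    }
    dfs = get_df_list_from_mongo(args)
    for df in dfs:
        print(df.shape)
    return dfs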
def insert_data_into_mongo(res_df, args):
    """
    Insert a DataFrame into a MongoDB collection, either overwriting the
    collection, appending to it, or upserting on a set of key columns.
    Parameters:
    - res_df: DataFrame to insert
    - args: dict with the MongoDB database and collection names; optional keys:
        - overwrite: 1 to drop and recreate the collection, 0 to append
        - update_keys: comma-separated key columns, e.g. 'col1,col2'; matching
          documents are updated, others are inserted (only when overwrite is 0)
    """
    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
    mongodb_database = args['mongodb_database']
    mongodb_write_table = args['mongodb_write_table']
    overwrite = 1
    update_keys = None
    if 'overwrite' in args:
        overwrite = int(args['overwrite'])
    if 'update_keys' in args:
        update_keys = args['update_keys'].split(',')
    client = MongoClient(mongodb_connection)
    db = client[mongodb_database]
    collection = db[mongodb_write_table]
    # Overwrite mode: drop the existing collection
    if overwrite:
        if mongodb_write_table in db.list_collection_names():
            collection.drop()
            print(f"Collection '{mongodb_write_table}' already exists, deleted successfully!")
    # Convert the DataFrame to a list of dicts (one per row)
    data_dict = res_df.to_dict("records")
    # Nothing to insert
    if not data_dict:
        print("No data to insert.")
        client.close()
        return
    # If update_keys is given, upsert (update if matched, insert otherwise)
    if update_keys and not overwrite:
        operations = []
        for record in data_dict:
            # Build the filter used to match the document to update
            query = {key: record[key] for key in update_keys}
            operations.append(UpdateOne(query, {'$set': record}, upsert=True))
        # Execute the updates/inserts in one bulk write
        if operations:
            result = collection.bulk_write(operations)
            print(f"Matched: {result.matched_count}, Upserts: {result.upserted_count}")
    else:
        # Append mode: insert the new documents directly
        collection.insert_many(data_dict)
        print("Data inserted successfully!")
    client.close()
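# Usage sketch for insert_data_into_mongo: upserts a small DataFrame on a key
# column. Database/collection/column names are hypothetical. Not executed at
# import time.
def _example_insert_data_into_mongo():
    res_df = pd.DataFrame({'dateTime': ['2024-01-01 00:15:00'], 'power': [12.3]})
    args = {
        'mongodb_database': 'power_forecast',     # hypothetical
        'mongodb_write_table': 'forecast_power',  # hypothetical
        'overwrite': 0,
        'update_keys': 'dateTime',                # upsert keyed on dateTime
    }
    insert_data_into_mongo(res_df, args)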
def get_data_fromMysql(params):
    mysql_conn = params['mysql_conn']
    query_sql = params['query_sql']
    # Read observed weather data from the MySQL database
    engine = create_engine(f"mysql+pymysql://{mysql_conn}")
    # Run the SQL query
    env_df = pd.read_sql_query(query_sql, engine)
    return env_df
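# Usage sketch for get_data_fromMysql: the connection string and table name are
# hypothetical placeholders; the function expects the part of the URL that
# follows "mysql+pymysql://". Not executed at import time.
def _example_get_data_fromMysql():
    params = {
        'mysql_conn': 'user:password@127.0.0.1:3306/weather_db',   # hypothetical
        'query_sql': 'SELECT * FROM observed_weather LIMIT 100',   # hypothetical
    }
    return get_data_fromMysql(params)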
def insert_pickle_model_into_mongo(model, args):
    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
    mongodb_database = args['mongodb_database']
    mongodb_write_table = args['mongodb_write_table']
    model_name = args['model_name']
    client = MongoClient(mongodb_connection)
    db = client[mongodb_database]
    # Serialize the model
    model_bytes = pickle.dumps(model)
    model_data = {
        'model_name': model_name,
        'model': model_bytes,  # store the pickled model bytes in the document
    }
    print('Training completed!')
    if mongodb_write_table in db.list_collection_names():
        db[mongodb_write_table].drop()
        print(f"Collection '{mongodb_write_table}' already exists, deleted successfully!")
    collection = db[mongodb_write_table]  # collection name
    collection.insert_one(model_data)
    print("model inserted successfully!")
    client.close()
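# Usage sketch for insert_pickle_model_into_mongo: stores any picklable model
# object. The stand-in model and the database/collection names are
# hypothetical. Not executed at import time.
def _example_insert_pickle_model_into_mongo():
    model = {'coef': [0.1, 0.2], 'intercept': 0.05}  # stand-in for a real estimator
    args = {
        'mongodb_database': 'power_forecast',    # hypothetical
        'mongodb_write_table': 'pickle_models',  # hypothetical
        'model_name': 'linear_baseline',         # hypothetical
    }
    insert_pickle_model_into_mongo(model, args)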
def insert_trained_model_into_mongo(model: tf.keras.Model, args: Dict[str, Any]) -> str:
    """
    Store a trained Keras model in MongoDB, keeping at most 50 models in the
    collection (the oldest one is evicted when the cap is reached).
    Parameters:
        model : trained Keras model
        args : dict with the following keys:
            mongodb_database: database name
            model_table: collection name
            model_name: model name
            gen_time: model generation time (datetime object)
            params: model parameters (JSON-serializable object)
            descr: model description text
    """
    # ------------------------- Parameter validation -------------------------
    required_keys = {'mongodb_database', 'model_table', 'model_name',
                     'gen_time', 'params', 'descr'}
    if missing := required_keys - args.keys():
        raise ValueError(f"Missing required parameters: {missing}")
    # ------------------------- Configuration -------------------------
    # Read the connection string from the environment when available (safer)
    mongodb_connection = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
    # ------------------------- Resource initialization -------------------------
    fd, temp_path = None, None
    client = None
    try:
        # ------------------------- Temporary file handling -------------------------
        fd, temp_path = tempfile.mkstemp(suffix='.keras')
        os.close(fd)  # release the file handle immediately
        # ------------------------- Model saving -------------------------
        try:
            model.save(temp_path)  # no save_format given: defaults to the native Keras format
        except Exception as e:
            raise RuntimeError(f"Failed to save model: {str(e)}") from e
        # ------------------------- Database connection -------------------------
        client = MongoClient(mongodb_connection)
        db = client[args['mongodb_database']]
        collection = db[args['model_table']]
        # ------------------------- Index check (currently disabled) -------------------------
        # index_info = collection.index_information()
        # if "gen_time_1" not in index_info:
        #     print("Creating index...")
        #     collection.create_index(
        #         [("gen_time", ASCENDING)],
        #         name="gen_time_1",
        #         background=True
        #     )
        #     print("Index created")
        # else:
        #     print("Index already exists, skipping creation")
        # ------------------------- Capacity control -------------------------
        # Use the cheaper approximate count
        if collection.estimated_document_count() >= 50:
            # Atomically delete the oldest document ({} matches any document)
            if deleted := collection.find_one_and_delete(
                {},
                sort=[("gen_time", ASCENDING)],
                projection={"_id": 1, "model_name": 1, "gen_time": 1}
            ):
                print(f"Evicted model [{deleted['model_name']}] generated at: {deleted['gen_time']}")
        # ------------------------- Insert -------------------------
        with open(temp_path, 'rb') as f:
            result = collection.insert_one({
                "model_name": args['model_name'],
                "model_data": f.read(),
                "gen_time": args['gen_time'],
                "params": args['params'],
                "descr": args['descr']
            })
        print(f"✅ Model {args['model_name']} saved | document ID: {result.inserted_id}")
        return str(result.inserted_id)
    except Exception as e:
        # ------------------------- Error classification -------------------------
        error_type = "database operation" if isinstance(e, (pymongo.errors.PyMongoError, RuntimeError)) else "system error"
        print(f"❌ {error_type} - details: {str(e)}")
        raise  # re-raise; callers decide how to handle it
    finally:
        # ------------------------- Cleanup -------------------------
        if client:
            client.close()
        if temp_path and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except PermissionError:
                print(f"⚠️ Failed to remove temporary file: {temp_path}")
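# Usage sketch for insert_trained_model_into_mongo: builds a tiny throwaway
# Keras model and stores it. All names and values are hypothetical; assumes a
# reachable MongoDB (MONGO_URI or the default URI). Not executed at import time.
def _example_insert_trained_model_into_mongo():
    from datetime import datetime
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(1)
    ])
    model.compile(optimizer='adam', loss='mse')
    args = {
        'mongodb_database': 'power_forecast',   # hypothetical
        'model_table': 'keras_models',          # hypothetical
        'model_name': 'lstm_demo',              # hypothetical
        'gen_time': datetime.now(),
        'params': {'hidden_units': 1},
        'descr': 'demo model for the usage sketch',
    }
    return insert_trained_model_into_mongo(model, args)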
def insert_scaler_model_into_mongo(feature_scaler_bytes: BytesIO, target_scaler_bytes: BytesIO, args: Dict[str, Any]) -> str:
    """
    Store feature/target scalers in MongoDB, keeping at most 50 documents in
    the collection (the oldest one is evicted when the cap is reached).
    Parameters:
        feature_scaler_bytes: BytesIO - serialized feature scaler
        target_scaler_bytes: BytesIO - serialized target scaler
        args : dict with the following keys:
            mongodb_database: database name
            scaler_table: collection name
            model_name: name of the associated model
            gen_time: generation time (datetime object)
    """
    # ------------------------- Parameter validation -------------------------
    required_keys = {'mongodb_database', 'scaler_table', 'model_name', 'gen_time'}
    if missing := required_keys - args.keys():
        raise ValueError(f"Missing required parameters: {missing}")
    # ------------------------- Configuration -------------------------
    # Read the connection string from the environment when available (keeps credentials out of code)
    mongodb_conn = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
    # ------------------------- Input validation -------------------------
    for buf, name in [(feature_scaler_bytes, "feature scaler"),
                      (target_scaler_bytes, "target scaler")]:
        if not isinstance(buf, BytesIO):
            raise TypeError(f"{name} must be a BytesIO object")
        if buf.getbuffer().nbytes == 0:
            raise ValueError(f"{name} byte stream is empty")
    client = None
    try:
        # ------------------------- Database connection -------------------------
        client = MongoClient(mongodb_conn)
        db = client[args['mongodb_database']]
        collection = db[args['scaler_table']]
        # ------------------------- Index maintenance (currently disabled) -------------------------
        # if "gen_time_1" not in collection.index_information():
        #     collection.create_index([("gen_time", ASCENDING)], name="gen_time_1")
        #     print("⏱️ Created gen_time index")
        # ------------------------- Capacity control -------------------------
        # Approximate count is cheaper; an error of a few dozen is acceptable here
        if collection.estimated_document_count() >= 50:
            # Atomically delete the oldest document ({} matches any document)
            if deleted := collection.find_one_and_delete(
                {},
                sort=[("gen_time", ASCENDING)],
                projection={"_id": 1, "model_name": 1, "gen_time": 1}
            ):
                print(f"🗑️ Evicted oldest scaler [{deleted['model_name']}] generated at: {deleted['gen_time']}")
        # ------------------------- Insert -------------------------
        # Make sure the byte-stream pointers are at the start
        feature_scaler_bytes.seek(0)
        target_scaler_bytes.seek(0)
        result = collection.insert_one({
            "model_name": args['model_name'],
            "gen_time": args['gen_time'],
            "feature_scaler": feature_scaler_bytes.read(),
            "target_scaler": target_scaler_bytes.read()
        })
        print(f"✅ Scalers for {args['model_name']} saved | document ID: {result.inserted_id}")
        return str(result.inserted_id)
    except Exception as e:
        # ------------------------- Error classification -------------------------
        error_type = "database operation" if isinstance(e, (pymongo.errors.PyMongoError, ValueError)) else "system error"
        print(f"❌ {error_type} error - details: {str(e)}")
        raise  # re-raise; callers decide how to handle it
    finally:
        # ------------------------- Cleanup -------------------------
        if client:
            client.close()
        # Reset stream pointers so the buffers can be reused by the caller
        feature_scaler_bytes.seek(0)
        target_scaler_bytes.seek(0)
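# Usage sketch for insert_scaler_model_into_mongo: serializes two fitted
# scalers into BytesIO buffers with joblib. Assumes scikit-learn is available
# in the environment; all names are hypothetical. Not executed at import time.
def _example_insert_scaler_model_into_mongo():
    from datetime import datetime
    from sklearn.preprocessing import MinMaxScaler  # assumed dependency
    feature_buf, target_buf = BytesIO(), BytesIO()
    joblib.dump(MinMaxScaler().fit([[0.0], [1.0]]), feature_buf)
    joblib.dump(MinMaxScaler().fit([[0.0], [100.0]]), target_buf)
    args = {
        'mongodb_database': 'power_forecast',  # hypothetical
        'scaler_table': 'scalers',             # hypothetical
        'model_name': 'lstm_demo',             # hypothetical
        'gen_time': datetime.now(),
    }
    return insert_scaler_model_into_mongo(feature_buf, target_buf, args)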
def get_h5_model_from_mongo(args: Dict[str, Any], custom_objects: Optional[Dict[str, Any]] = None) -> Optional[tf.keras.Model]:
    """
    Load the latest version of a named model from MongoDB.
    Parameters:
        args : dict with the following keys:
            mongodb_database: database name
            model_table: collection name
            model_name: name of the model to load
        custom_objects: dict - custom Keras objects needed for deserialization
    Returns:
        tf.keras.Model - the loaded Keras model, or None if no record is found
    """
    # ------------------------- Parameter validation -------------------------
    required_keys = {'mongodb_database', 'model_table', 'model_name'}
    if missing := required_keys - args.keys():
        raise ValueError(f"❌ Missing required parameters: {missing}")
    # ------------------------- Environment configuration -------------------------
    mongo_uri = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
    client = None
    tmp_file_path = None  # tracks the temporary file path
    try:
        # ------------------------- Database connection -------------------------
        client = MongoClient(
            mongo_uri,
            maxPoolSize=10,  # connection-pool tuning
            socketTimeoutMS=5000
        )
        db = client[args['mongodb_database']]
        collection = db[args['model_table']]
        # ------------------------- Index maintenance -------------------------
        index_name = "model_gen_time_idx"
        if index_name not in collection.index_information():
            collection.create_index(
                [("model_name", 1), ("gen_time", DESCENDING)],
                name=index_name
            )
            print("⏱️ Created compound index")
        # ------------------------- Query -------------------------
        model_doc = collection.find_one(
            {"model_name": args['model_name']},
            sort=[('gen_time', DESCENDING)],
            projection={"model_data": 1, "gen_time": 1}  # fetch only the required fields
        )
        if not model_doc:
            print(f"⚠️ No record found for model '{args['model_name']}'")
            return None
        # ------------------------- Model loading -------------------------
        model_data = model_doc['model_data']  # the model's binary payload
        # An earlier in-memory variant loaded the bytes through BytesIO/h5py:
        # with h5py.File(BytesIO(model_data), 'r') as f:
        #     model = tf.keras.models.load_model(f, custom_objects=custom_objects)
        # The temporary-file route below is used instead.
        with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
            tmp_file.write(model_data)
            tmp_file_path = tmp_file.name  # remember the path for cleanup
        # Load the model from the temporary file
        model = tf.keras.models.load_model(tmp_file_path, custom_objects=custom_objects)
        print(f"Model {args['model_name']} loaded from MongoDB!")
        return model
    except tf.errors.NotFoundError as e:
        print(f"❌ Model is missing required components: {str(e)}")
        raise RuntimeError("Incomplete model architecture") from e
    except Exception as e:
        print(f"❌ Unexpected error: {str(e)}")
        raise
    finally:
        # ------------------------- Cleanup -------------------------
        if client:
            client.close()
        # Make sure the temporary file is removed
        if tmp_file_path and os.path.exists(tmp_file_path):
            try:
                os.remove(tmp_file_path)
                print(f"🧹 Removed temporary file: {tmp_file_path}")
            except Exception as cleanup_err:
                print(f"⚠️ Failed to remove temporary file: {str(cleanup_err)}")
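# Usage sketch for get_h5_model_from_mongo: fetches the most recent model saved
# under a given name. The names are hypothetical; custom_objects is only needed
# when the model uses custom layers or losses. Not executed at import time.
def _example_get_h5_model_from_mongo():
    args = {
        'mongodb_database': 'power_forecast',  # hypothetical
        'model_table': 'keras_models',         # hypothetical
        'model_name': 'lstm_demo',             # hypothetical
    }
    model = get_h5_model_from_mongo(args)
    if model is not None:
        model.summary()
    return model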
def get_keras_model_from_mongo(
        args: Dict[str, Any],
        custom_objects: Optional[Dict[str, Any]] = None
) -> Optional[Tuple[tf.keras.Model, Dict[str, Any]]]:
    """
    Load the latest version of a named model from MongoDB (Keras format),
    together with the parameters stored alongside it.
    Parameters:
        args : dict with the following keys:
            mongodb_database: database name
            model_table: collection name
            model_name: name of the model to load
        custom_objects: dict - custom Keras objects needed for deserialization
    Returns:
        (model, params) tuple, or None if no record is found
    """
    # ------------------------- Parameter validation -------------------------
    required_keys = {'mongodb_database', 'model_table', 'model_name'}
    if missing := required_keys - args.keys():
        raise ValueError(f"❌ Missing required parameters: {missing}")
    # ------------------------- Environment configuration -------------------------
    mongo_uri = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
    client = None
    tmp_file_path = None  # tracks the temporary file path
    try:
        # ------------------------- Database connection -------------------------
        client = MongoClient(
            mongo_uri,
            maxPoolSize=10,
            socketTimeoutMS=5000
        )
        db = client[args['mongodb_database']]
        collection = db[args['model_table']]
        # ------------------------- Index maintenance (currently disabled) -------------------------
        # index_name = "model_gen_time_idx"
        # if index_name not in collection.index_information():
        #     collection.create_index(
        #         [("model_name", 1), ("gen_time", DESCENDING)],
        #         name=index_name
        #     )
        #     print("⏱️ Created compound index")
        # ------------------------- Query -------------------------
        model_doc = collection.find_one(
            {"model_name": args['model_name']},
            sort=[('gen_time', DESCENDING)],
            projection={"model_data": 1, "gen_time": 1, 'params': 1}
        )
        if not model_doc:
            print(f"⚠️ No record found for model '{args['model_name']}'")
            return None
        # ------------------------- Model loading -------------------------
        model_data = model_doc['model_data']
        model_params = model_doc['params']
        # Write the payload to a temporary file (removed in the finally block)
        with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
            tmp_file.write(model_data)
            tmp_file_path = tmp_file.name  # remember the path for cleanup
        # Load the model from the temporary file
        model = tf.keras.models.load_model(
            tmp_file_path,
            custom_objects=custom_objects
        )
        print(f"Model {args['model_name']} loaded from MongoDB!")
        return model, model_params
    except tf.errors.NotFoundError as e:
        print(f"❌ Model is missing required components: {str(e)}")
        raise RuntimeError("Incomplete model architecture") from e
    except Exception as e:
        print(f"❌ Unexpected error: {str(e)}")
        raise
    finally:
        # ------------------------- Cleanup -------------------------
        if client:
            client.close()
        # Make sure the temporary file is removed
        if tmp_file_path and os.path.exists(tmp_file_path):
            try:
                os.remove(tmp_file_path)
                print(f"🧹 Removed temporary file: {tmp_file_path}")
            except Exception as cleanup_err:
                print(f"⚠️ Failed to remove temporary file: {str(cleanup_err)}")
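# Usage sketch for get_keras_model_from_mongo: unlike get_h5_model_from_mongo,
# this returns a (model, params) tuple. Names are hypothetical. Not executed
# at import time.
def _example_get_keras_model_from_mongo():
    args = {
        'mongodb_database': 'power_forecast',  # hypothetical
        'model_table': 'keras_models',         # hypothetical
        'model_name': 'lstm_demo',             # hypothetical
    }
    loaded = get_keras_model_from_mongo(args)
    if loaded is not None:
        model, params = loaded
        print(params)
        return model
    return None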
def get_scaler_model_from_mongo(args: Dict[str, Any], only_feature_scaler: bool = False) -> Union[object, Tuple[object, object]]:
    """
    Load the latest feature/target scalers for a model from MongoDB.
    Parameters:
        args : dict with the following keys:
            - mongodb_database: database name
            - scaler_table: collection name
            - model_name: target model name
        only_feature_scaler : if True, return only the feature scaler
    Returns:
        A single scaler object, or a (feature_scaler, target_scaler) tuple
    Raises:
        ValueError : missing parameters or wrong types
        RuntimeError : database or deserialization failure
    """
    # ------------------------- Parameter validation -------------------------
    required_keys = {'mongodb_database', 'scaler_table', 'model_name'}
    if missing := required_keys - args.keys():
        raise ValueError(f"❌ Missing required parameters: {missing}")
    # ------------------------- Environment configuration -------------------------
    mongo_uri = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
    client = None
    try:
        # ------------------------- Database connection -------------------------
        client = MongoClient(
            mongo_uri,
            maxPoolSize=20,                 # connection-pool cap
            socketTimeoutMS=3000,           # 3 s socket timeout
            serverSelectionTimeoutMS=5000   # 5 s server-selection timeout
        )
        db = client[args['mongodb_database']]
        collection = db[args['scaler_table']]
        # ------------------------- Index maintenance (currently disabled) -------------------------
        # index_name = "model_gen_time_idx"
        # if index_name not in collection.index_information():
        #     collection.create_index(
        #         [("model_name", 1), ("gen_time", DESCENDING)],
        #         name=index_name,
        #         background=True  # build in the background to avoid blocking
        #     )
        #     print("⏱️ Created compound index for scalers")
        # ------------------------- Query -------------------------
        scaler_doc = collection.find_one(
            {"model_name": args['model_name']},
            sort=[('gen_time', DESCENDING)],
            projection={"feature_scaler": 1, "target_scaler": 1, "gen_time": 1}
        )
        if not scaler_doc:
            raise RuntimeError(f"⚠️ No scaler record found for model {args['model_name']}")
        # ------------------------- Deserialization -------------------------
        def load_scaler(data: bytes) -> object:
            """Safely load a serialized object."""
            with BytesIO(data) as buffer:
                buffer.seek(0)  # make sure the pointer is at the start
                try:
                    return joblib.load(buffer)
                except pickle.UnpicklingError as e:
                    raise RuntimeError("Deserialization failed (possible version mismatch)") from e
        # Feature scaler
        feature_data = scaler_doc["feature_scaler"]
        if not isinstance(feature_data, bytes):
            raise RuntimeError("Unexpected feature scaler data format")
        feature_scaler = load_scaler(feature_data)
        if only_feature_scaler:
            return feature_scaler
        # Target scaler
        target_data = scaler_doc["target_scaler"]
        if not isinstance(target_data, bytes):
            raise RuntimeError("Unexpected target scaler data format")
        target_scaler = load_scaler(target_data)
        print(f"✅ Loaded scalers for {args['model_name']} (version time: {scaler_doc.get('gen_time', 'unknown')})")
        return feature_scaler, target_scaler
    except PyMongoError as e:
        raise RuntimeError(f"🔌 Database operation failed: {str(e)}") from e
    except RuntimeError:
        raise  # already wrapped above; propagate as-is
    except Exception as e:
        raise RuntimeError(f"❌ Unexpected system error: {str(e)}") from e
    finally:
        # ------------------------- Cleanup -------------------------
        if client:
            client.close()
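# Usage sketch for get_scaler_model_from_mongo: loads both scalers and applies
# the feature scaler to new data. Names are hypothetical; the input shape must
# match whatever the stored scaler was fitted on. Not executed at import time.
def _example_get_scaler_model_from_mongo():
    args = {
        'mongodb_database': 'power_forecast',  # hypothetical
        'scaler_table': 'scalers',             # hypothetical
        'model_name': 'lstm_demo',             # hypothetical
    }
    feature_scaler, target_scaler = get_scaler_model_from_mongo(args)
    scaled = feature_scaler.transform([[0.5]])  # shape must match the fitted scaler
    print(scaled)
    return feature_scaler, target_scaler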