"""MongoDB / MySQL I/O helpers for tabular data, Keras models and feature scalers."""
import io
import os
import pickle
import tempfile
from io import BytesIO
from typing import Any, Dict, Optional, Tuple, Union

import h5py
import joblib
import pandas as pd
import pymongo
import tensorflow as tf
from pymongo import ASCENDING, DESCENDING, MongoClient, UpdateOne
from pymongo.errors import PyMongoError
from sqlalchemy import create_engine

# NOTE(review): a credential-bearing URI is kept as the fallback only for backward
# compatibility. Set MONGO_URI in the environment and remove this default.
_DEFAULT_MONGO_URI = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"

# Upper bound on the number of documents kept in the model / scaler collections.
_MAX_STORED_DOCS = 50


def _mongo_uri() -> str:
    """Return the MongoDB URI, preferring the ``MONGO_URI`` environment variable."""
    return os.getenv("MONGO_URI", _DEFAULT_MONGO_URI)


def _drop_id_column(df: pd.DataFrame) -> pd.DataFrame:
    """Return ``df`` without MongoDB's ``_id`` column (unchanged if the column is absent)."""
    return df.drop(columns=['_id']) if '_id' in df.columns else df


def _evict_oldest(collection, max_docs: int, message: str) -> None:
    """Delete the oldest documents (by ``gen_time``) so one more insert stays within ``max_docs``.

    ``message`` is a ``str.format`` template receiving ``model_name`` and ``gen_time``
    of each evicted document; it is printed once per deletion.
    """
    # estimated_document_count reads collection metadata: cheap, but approximate.
    excess = collection.estimated_document_count() - max_docs + 1
    for _ in range(max(excess, 0)):
        # find_one_and_delete requires a filter argument; {} matches any document.
        deleted = collection.find_one_and_delete(
            {},
            projection={"_id": 1, "model_name": 1, "gen_time": 1},
            sort=[("gen_time", ASCENDING)],
        )
        if deleted is None:
            break
        print(message.format(model_name=deleted.get('model_name'),
                             gen_time=deleted.get('gen_time')))


def _load_keras_model_from_bytes(model_data: bytes,
                                 custom_objects: Optional[Dict[str, Any]] = None) -> tf.keras.Model:
    """Load a Keras model from serialized ``.keras`` bytes via a temporary file.

    The temporary file is always removed afterwards (best effort; failures are printed).
    """
    tmp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
            # Record the path before writing so a failed write still gets cleaned up.
            tmp_file_path = tmp_file.name
            tmp_file.write(model_data)
        return tf.keras.models.load_model(tmp_file_path, custom_objects=custom_objects)
    finally:
        if tmp_file_path and os.path.exists(tmp_file_path):
            try:
                os.remove(tmp_file_path)
                print(f"🧹 已清理临时文件: {tmp_file_path}")
            except OSError as cleanup_err:
                print(f"⚠️ 临时文件清理失败: {str(cleanup_err)}")


def get_data_from_mongo(args):
    """Read one MongoDB collection into a DataFrame, optionally filtered on ``dateTime``.

    Args:
        args: dict with ``mongodb_database`` and ``mongodb_read_table``. The optional
            keys ``timeBegin`` / ``timeEnd`` bound the ``dateTime`` field inclusively
            (``$gte`` / ``$lte``).

    Returns:
        pandas.DataFrame of the matching documents, without the ``_id`` column.
    """
    time_range = {}
    if 'timeBegin' in args:
        time_range["$gte"] = args['timeBegin']
    if 'timeEnd' in args:
        time_range["$lte"] = args['timeEnd']
    query = {"dateTime": time_range} if time_range else {}

    # The context manager closes the client even when the query fails.
    with MongoClient(_mongo_uri()) as client:
        # MongoDB creates databases/collections lazily, so missing names yield no rows.
        collection = client[args['mongodb_database']][args['mongodb_read_table']]
        df = pd.DataFrame(list(collection.find(query)))
    return _drop_id_column(df)


def get_df_list_from_mongo(args):
    """Read several collections into a list of DataFrames.

    Args:
        args: dict with ``mongodb_database`` and ``mongodb_read_table``, the latter a
            comma-separated list of collection names.

    Returns:
        list of pandas.DataFrame, one per collection in the given order, each without
        the ``_id`` column.
    """
    tables = args['mongodb_read_table'].split(',')
    with MongoClient(_mongo_uri()) as client:
        db = client[args['mongodb_database']]
        return [_drop_id_column(pd.DataFrame(list(db[table].find()))) for table in tables]


def insert_data_into_mongo(res_df, args):
    """Write a DataFrame to a MongoDB collection by overwrite, append, or keyed upsert.

    Args:
        res_df: DataFrame to write; each row becomes one document.
        args: dict with ``mongodb_database`` and ``mongodb_write_table``. Optional keys:
            - ``overwrite``: int-like flag (default 1). When truthy, an existing
              collection is dropped before writing.
            - ``update_keys``: comma-separated column names, e.g. ``'col1,col2'``. Used
              only when ``overwrite`` is falsy: rows matching on these columns are
              updated, all others are inserted (upsert).

    Without ``update_keys`` (or with ``overwrite`` set), rows are appended with
    ``insert_many``. An empty DataFrame writes nothing, but the overwrite drop still
    happens first.
    """
    table = args['mongodb_write_table']
    overwrite = int(args['overwrite']) if 'overwrite' in args else 1
    update_keys = args['update_keys'].split(',') if 'update_keys' in args else None

    with MongoClient(_mongo_uri()) as client:
        db = client[args['mongodb_database']]
        collection = db[table]

        if overwrite and table in db.list_collection_names():
            collection.drop()
            print(f"Collection '{table}' already exists, deleted successfully!")

        records = res_df.to_dict("records")  # one dict per row
        if not records:
            print("No data to insert.")
            return

        if update_keys and not overwrite:
            # Match each row on its update_keys values; insert when nothing matches.
            operations = [
                UpdateOne({key: record[key] for key in update_keys}, {'$set': record}, upsert=True)
                for record in records
            ]
            result = collection.bulk_write(operations)
            print(f"Matched: {result.matched_count}, Upserts: {result.upserted_count}")
        else:
            collection.insert_many(records)
            print("Data inserted successfully!")


def get_data_fromMysql(params):
    """Run a SQL query against MySQL (via pymysql) and return the result as a DataFrame.

    Args:
        params: dict with ``mysql_conn`` (``user:pass@host:port/db`` part of the URL)
            and ``query_sql``.
    """
    mysql_conn = params['mysql_conn']
    query_sql = params['query_sql']
    engine = create_engine(f"mysql+pymysql://{mysql_conn}")
    try:
        return pd.read_sql_query(query_sql, engine)
    finally:
        # Release the engine's connection pool; the engine is not reused.
        engine.dispose()


def insert_pickle_model_into_mongo(model, args):
    """Pickle ``model`` and store it as the only document of a collection.

    Any existing collection named ``args['mongodb_write_table']`` is dropped first.

    Args:
        model: any picklable object.
        args: dict with ``mongodb_database``, ``mongodb_write_table`` and ``model_name``.
    """
    table = args['mongodb_write_table']
    model_data = {
        'model_name': args['model_name'],
        'model': pickle.dumps(model),  # serialized model bytes
    }
    print('Training completed!')

    with MongoClient(_mongo_uri()) as client:
        db = client[args['mongodb_database']]
        if table in db.list_collection_names():
            db[table].drop()
            print(f"Collection '{table} already exist, deleted successfully!")
        db[table].insert_one(model_data)
    print("model inserted successfully!")


def insert_trained_model_into_mongo(model: tf.keras.Model, args: Dict[str, Any]) -> str:
    """Save a Keras model in ``.keras`` format and store its bytes in MongoDB.

    The oldest documents (by ``gen_time``) are evicted first, so the collection holds at
    most ``_MAX_STORED_DOCS`` (50) documents after the insert.

    Args:
        model: trained Keras model.
        args: dict with keys
            mongodb_database: database name
            model_table: collection name
            model_name: model name
            gen_time: generation time (datetime)
            params: model parameters (BSON-serializable)
            descr: description text

    Returns:
        The inserted document's ``_id`` as a string.

    Raises:
        ValueError: a required key is missing from ``args``.
        RuntimeError: ``model.save`` failed.
        Any database error is printed and re-raised.
    """
    required_keys = {'mongodb_database', 'model_table', 'model_name', 'gen_time', 'params', 'descr'}
    if missing := required_keys - args.keys():
        raise ValueError(f"缺少必要参数: {missing}")

    temp_path = None
    client = None
    try:
        fd, temp_path = tempfile.mkstemp(suffix='.keras')
        os.close(fd)  # Keras reopens the path itself; release our handle first.

        try:
            model.save(temp_path)  # the .keras suffix selects the native Keras format
        except Exception as e:
            raise RuntimeError(f"模型保存失败: {str(e)}") from e

        client = MongoClient(_mongo_uri())
        collection = client[args['mongodb_database']][args['model_table']]

        _evict_oldest(collection, _MAX_STORED_DOCS, "已淘汰模型 [{model_name}] 生成时间: {gen_time}")

        with open(temp_path, 'rb') as f:
            result = collection.insert_one({
                "model_name": args['model_name'],
                "model_data": f.read(),
                "gen_time": args['gen_time'],
                "params": args['params'],
                "descr": args['descr'],
            })
        print(f"✅ 模型 {args['model_name']} 保存成功 | 文档ID: {result.inserted_id}")
        return str(result.inserted_id)

    except Exception as e:
        error_type = "数据库操作" if isinstance(e, (PyMongoError, RuntimeError)) else "系统错误"
        print(f"❌ {error_type} - 详细错误: {str(e)}")
        raise
    finally:
        if client is not None:
            client.close()
        if temp_path and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError:
                # Best-effort cleanup: never let it mask the primary outcome.
                print(f"⚠️ 临时文件清理失败: {temp_path}")


def insert_scaler_model_into_mongo(feature_scaler_bytes: BytesIO,
                                   target_scaler_bytes: BytesIO,
                                   args: Dict[str, Any]) -> str:
    """Store serialized feature/target scalers in MongoDB.

    The oldest documents (by ``gen_time``) are evicted first, so the collection holds at
    most ``_MAX_STORED_DOCS`` (50) documents after the insert.

    Args:
        feature_scaler_bytes: non-empty buffer with the serialized feature scaler.
        target_scaler_bytes: non-empty buffer with the serialized target scaler.
        args: dict with keys
            mongodb_database: database name
            scaler_table: collection name
            model_name: name of the associated model
            gen_time: generation time (datetime)

    Returns:
        The inserted document's ``_id`` as a string.

    Raises:
        ValueError: missing key in ``args`` or an empty buffer.
        TypeError: a buffer is not a ``BytesIO``.
        Any database error is printed and re-raised.

    Both buffers are rewound to position 0 on exit so callers can reuse them.
    """
    required_keys = {'mongodb_database', 'scaler_table', 'model_name', 'gen_time'}
    if missing := required_keys - args.keys():
        raise ValueError(f"缺少必要参数: {missing}")

    for buf, name in [(feature_scaler_bytes, "特征缩放器"), (target_scaler_bytes, "目标缩放器")]:
        if not isinstance(buf, BytesIO):
            raise TypeError(f"{name} 必须为BytesIO类型")
        if buf.getbuffer().nbytes == 0:
            raise ValueError(f"{name} 字节流为空")

    client = None
    try:
        client = MongoClient(_mongo_uri())
        collection = client[args['mongodb_database']][args['scaler_table']]

        _evict_oldest(collection, _MAX_STORED_DOCS, "🗑️ 已淘汰最旧缩放器 [{model_name}] 生成时间: {gen_time}")

        # Read each buffer from its start, regardless of the caller's position.
        feature_scaler_bytes.seek(0)
        target_scaler_bytes.seek(0)
        result = collection.insert_one({
            "model_name": args['model_name'],
            "gen_time": args['gen_time'],
            "feature_scaler": feature_scaler_bytes.read(),
            "target_scaler": target_scaler_bytes.read(),
        })
        print(f"✅ 缩放器 {args['model_name']} 保存成功 | 文档ID: {result.inserted_id}")
        return str(result.inserted_id)

    except Exception as e:
        error_type = "数据库操作" if isinstance(e, (PyMongoError, ValueError)) else "系统错误"
        print(f"❌ {error_type}异常 - 详细错误: {str(e)}")
        raise
    finally:
        if client is not None:
            client.close()
        feature_scaler_bytes.seek(0)
        target_scaler_bytes.seek(0)


def get_h5_model_from_mongo(
        args: Dict[str, Any],
        custom_objects: Optional[Dict[str, Any]] = None) -> Optional[tf.keras.Model]:
    """Load the newest stored version (highest ``gen_time``) of a Keras model.

    Ensures a compound ``(model_name, gen_time desc)`` index exists, creating it if needed.

    Args:
        args: dict with ``mongodb_database``, ``model_table`` and ``model_name``.
        custom_objects: custom Keras objects passed to ``load_model``.

    Returns:
        The loaded model, or None when no document matches ``model_name``.

    Raises:
        ValueError: a required key is missing from ``args``.
        RuntimeError: Keras raised ``tf.errors.NotFoundError`` while loading.
        Other errors are printed and re-raised.
    """
    required_keys = {'mongodb_database', 'model_table', 'model_name'}
    if missing := required_keys - args.keys():
        raise ValueError(f"❌ 缺失必要参数: {missing}")

    client = None
    try:
        client = MongoClient(_mongo_uri(), maxPoolSize=10, socketTimeoutMS=5000)
        collection = client[args['mongodb_database']][args['model_table']]

        index_name = "model_gen_time_idx"
        if index_name not in collection.index_information():
            collection.create_index(
                [("model_name", ASCENDING), ("gen_time", DESCENDING)],
                name=index_name,
            )
            print("⏱️ 已创建复合索引")

        model_doc = collection.find_one(
            {"model_name": args['model_name']},
            sort=[('gen_time', DESCENDING)],
            projection={"model_data": 1, "gen_time": 1},
        )
        if not model_doc:
            print(f"⚠️ 未找到模型 '{args['model_name']}' 的有效记录")
            return None

        model = _load_keras_model_from_bytes(model_doc['model_data'], custom_objects)
        print(f"{args['model_name']}模型成功从 MongoDB 加载!")
        return model

    except tf.errors.NotFoundError as e:
        print(f"❌ 模型结构缺失关键组件: {str(e)}")
        raise RuntimeError("模型架构不完整") from e
    except Exception as e:
        print(f"❌ 系统异常: {str(e)}")
        raise
    finally:
        if client is not None:
            client.close()


def get_keras_model_from_mongo(
        args: Dict[str, Any],
        custom_objects: Optional[Dict[str, Any]] = None
) -> Optional[Tuple[tf.keras.Model, Any]]:
    """Load the newest stored version of a Keras model together with its ``params``.

    Args:
        args: dict with ``mongodb_database``, ``model_table`` and ``model_name``.
        custom_objects: custom Keras objects passed to ``load_model``.

    Returns:
        ``(model, params)``, or None (not a tuple) when no document matches
        ``model_name``. ``params`` is None when the document has no ``params`` field.

    Raises:
        ValueError: a required key is missing from ``args``.
        RuntimeError: Keras raised ``tf.errors.NotFoundError`` while loading.
        Other errors are printed and re-raised.
    """
    required_keys = {'mongodb_database', 'model_table', 'model_name'}
    if missing := required_keys - args.keys():
        raise ValueError(f"❌ 缺失必要参数: {missing}")

    client = None
    try:
        client = MongoClient(_mongo_uri(), maxPoolSize=10, socketTimeoutMS=5000)
        collection = client[args['mongodb_database']][args['model_table']]

        model_doc = collection.find_one(
            {"model_name": args['model_name']},
            sort=[('gen_time', DESCENDING)],
            projection={"model_data": 1, "gen_time": 1, 'params': 1},
        )
        if not model_doc:
            print(f"⚠️ 未找到模型 '{args['model_name']}' 的有效记录")
            return None

        model = _load_keras_model_from_bytes(model_doc['model_data'], custom_objects)
        print(f"{args['model_name']} 模型成功从 MongoDB 加载!")
        return model, model_doc.get('params')

    except tf.errors.NotFoundError as e:
        print(f"❌ 模型结构缺失关键组件: {str(e)}")
        raise RuntimeError("模型架构不完整") from e
    except Exception as e:
        print(f"❌ 系统异常: {str(e)}")
        raise
    finally:
        if client is not None:
            client.close()


def _load_scaler(data: Any, format_error: str) -> object:
    """Deserialize a joblib-dumped scaler stored as raw bytes.

    Raises RuntimeError(format_error) when ``data`` is not ``bytes``, and RuntimeError
    when unpickling fails.
    """
    if not isinstance(data, bytes):
        raise RuntimeError(format_error)
    # SECURITY: joblib.load unpickles arbitrary objects; only use with a trusted database.
    with BytesIO(data) as buffer:
        try:
            return joblib.load(buffer)
        except pickle.UnpicklingError as e:
            raise RuntimeError("反序列化失败 (可能版本不兼容)") from e


def get_scaler_model_from_mongo(args: Dict[str, Any],
                                only_feature_scaler: bool = False) -> Union[object, Tuple[object, object]]:
    """Load the newest stored feature/target scalers for a model.

    Args:
        args: dict with ``mongodb_database``, ``scaler_table`` and ``model_name``.
        only_feature_scaler: return only the feature scaler when True.

    Returns:
        The feature scaler, or a ``(feature_scaler, target_scaler)`` tuple.

    Raises:
        ValueError: a required key is missing from ``args``.
        RuntimeError: no matching document, malformed or undeserializable data, a
            database error, or any other unexpected failure.
    """
    required_keys = {'mongodb_database', 'scaler_table', 'model_name'}
    if missing := required_keys - args.keys():
        raise ValueError(f"❌ 缺失必要参数: {missing}")

    client = None
    try:
        client = MongoClient(
            _mongo_uri(),
            maxPoolSize=20,
            socketTimeoutMS=3000,
            serverSelectionTimeoutMS=5000,
        )
        collection = client[args['mongodb_database']][args['scaler_table']]

        scaler_doc = collection.find_one(
            {"model_name": args['model_name']},
            sort=[('gen_time', DESCENDING)],
            projection={"feature_scaler": 1, "target_scaler": 1, "gen_time": 1},
        )
        if not scaler_doc:
            raise RuntimeError(f"⚠️ 找不到模型 {args['model_name']} 的缩放器记录")

        feature_scaler = _load_scaler(scaler_doc["feature_scaler"], "特征缩放器数据格式异常")
        if only_feature_scaler:
            return feature_scaler

        target_scaler = _load_scaler(scaler_doc["target_scaler"], "目标缩放器数据格式异常")
        print(f"✅ 成功加载 {args['model_name']} 的缩放器 (版本时间: {scaler_doc.get('gen_time', '未知')})")
        return feature_scaler, target_scaler

    except PyMongoError as e:
        raise RuntimeError(f"🔌 数据库操作失败: {str(e)}") from e
    except RuntimeError:
        # Already carries a descriptive message from above; propagate it unchanged.
        raise
    except Exception as e:
        raise RuntimeError(f"❌ 未知系统异常: {str(e)}") from e
    finally:
        if client is not None:
            client.close()