from pymongo import MongoClient, UpdateOne
import pandas as pd
from sqlalchemy import create_engine
import pickle
from io import BytesIO
import joblib
import h5py
import tensorflow as tf

# NOTE(review): credentials are hard-coded in source — move to env/config. TODO.
# Hoisted to a single constant; the URI was previously duplicated in every function.
MONGODB_CONNECTION = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"


def get_data_from_mongo(args):
    """Read one MongoDB collection into a pandas DataFrame.

    Args (dict keys):
        mongodb_database: database name.
        mongodb_read_table: collection name.
        timeBegin / timeEnd (optional): inclusive bounds applied to the
            documents' ``dateTime`` field.

    Returns:
        DataFrame of the matching documents, without the ``_id`` column.
    """
    mongodb_database = args['mongodb_database']
    mongodb_read_table = args['mongodb_read_table']

    # Build the optional dateTime range filter.
    query_dict = {}
    if 'timeBegin' in args:
        query_dict["$gte"] = args['timeBegin']
    if 'timeEnd' in args:
        query_dict["$lte"] = args['timeEnd']

    client = MongoClient(MONGODB_CONNECTION)
    try:
        collection = client[mongodb_database][mongodb_read_table]
        if query_dict:
            cursor = collection.find({"dateTime": query_dict})
        else:
            cursor = collection.find()
        df = pd.DataFrame(list(cursor))
        # Drop Mongo's internal id column if present.
        if '_id' in df.columns:
            df = df.drop(columns=['_id'])
        return df
    finally:
        # Fix: original leaked the connection if an exception occurred mid-read.
        client.close()


def get_df_list_from_mongo(args):
    """Read several MongoDB collections into a list of DataFrames.

    Args (dict keys):
        mongodb_database: database name.
        mongodb_read_table: comma-separated collection names.

    Returns:
        list[DataFrame], one per collection (order follows the input list),
        each without the ``_id`` column.
    """
    mongodb_database = args['mongodb_database']
    table_names = args['mongodb_read_table'].split(',')

    client = MongoClient(MONGODB_CONNECTION)
    try:
        db = client[mongodb_database]
        df_list = []
        for table in table_names:
            df = pd.DataFrame(list(db[table].find()))
            if '_id' in df.columns:
                df = df.drop(columns=['_id'])
            df_list.append(df)
        return df_list
    finally:
        # Fix: close even when a read raises.
        client.close()


def insert_data_into_mongo(res_df, args):
    """Insert a DataFrame into a MongoDB collection.

    Supports three modes:
      - overwrite (default, args['overwrite'] truthy): drop the collection
        first, then insert all rows;
      - append (overwrite=0, no update_keys): plain insert_many;
      - upsert (overwrite=0 and args['update_keys'] = 'col1,col2'): for each
        row, update the document matching the key columns, inserting if absent.

    Args (dict keys): mongodb_database, mongodb_write_table,
        optional overwrite (int-like), optional update_keys (comma-separated).
    """
    mongodb_database = args['mongodb_database']
    mongodb_write_table = args['mongodb_write_table']
    overwrite = int(args['overwrite']) if 'overwrite' in args else 1
    update_keys = args['update_keys'].split(',') if 'update_keys' in args else None

    client = MongoClient(MONGODB_CONNECTION)
    try:
        db = client[mongodb_database]
        collection = db[mongodb_write_table]

        # Overwrite mode: drop any existing collection first.
        if overwrite and mongodb_write_table in db.list_collection_names():
            collection.drop()
            print(f"Collection '{mongodb_write_table}' already exists, deleted successfully!")

        # One dict per row.
        data_dict = res_df.to_dict("records")
        if not data_dict:
            print("No data to insert.")
            return

        if update_keys and not overwrite:
            # Upsert mode: match on the key columns, update or insert.
            operations = [
                UpdateOne({key: record[key] for key in update_keys},
                          {'$set': record}, upsert=True)
                for record in data_dict
            ]
            if operations:
                result = collection.bulk_write(operations)
                print(f"Matched: {result.matched_count}, Upserts: {result.upserted_count}")
        else:
            # Overwrite/append mode: plain bulk insert.
            collection.insert_many(data_dict)
            print("Data inserted successfully!")
    finally:
        # Fix: original never closed the client (including the early
        # "no data" return path) — connection leak.
        client.close()


def get_data_fromMysql(params):
    """Run a SQL query against MySQL and return the result as a DataFrame.

    Args (dict keys):
        mysql_conn: DSN tail, e.g. 'user:pass@host:port/db'.
        query_sql: the SELECT statement to execute.
    """
    mysql_conn = params['mysql_conn']
    query_sql = params['query_sql']
    engine = create_engine(f"mysql+pymysql://{mysql_conn}")
    try:
        return pd.read_sql_query(query_sql, engine)
    finally:
        # Fix: release pooled connections; original never disposed the engine.
        engine.dispose()


def insert_pickle_model_into_mongo(model, args):
    """Pickle a model object and store it in a MongoDB collection.

    The target collection is dropped first, so it holds exactly one document:
    {'model_name': ..., 'model': <pickle bytes>}.

    Args (dict keys): mongodb_database, mongodb_write_table, model_name.
    """
    mongodb_database = args['mongodb_database']
    mongodb_write_table = args['mongodb_write_table']
    model_name = args['model_name']

    client = MongoClient(MONGODB_CONNECTION)
    try:
        db = client[mongodb_database]
        model_bytes = pickle.dumps(model)  # serialize the model in memory
        model_data = {
            'model_name': model_name,
            'model': model_bytes,  # binary payload stored directly in the doc
        }
        print('Training completed!')

        if mongodb_write_table in db.list_collection_names():
            db[mongodb_write_table].drop()
            # Fix: message had a missing closing quote; made consistent with
            # insert_data_into_mongo's wording.
            print(f"Collection '{mongodb_write_table}' already exists, deleted successfully!")
        db[mongodb_write_table].insert_one(model_data)
        print("model inserted successfully!")
    finally:
        # Fix: original never closed the client.
        client.close()


def _keras_model_to_bytes(model):
    """Serialize a Keras model to HDF5 bytes entirely in memory.

    Shared by the h5/trained-model insert helpers (was duplicated inline).
    """
    model_buffer = BytesIO()
    model.save(model_buffer, save_format='h5')
    model_buffer.seek(0)
    return model_buffer.read()


def insert_h5_model_into_mongo(model, feature_scaler_bytes, target_scaler_bytes, args):
    """Store a Keras model (HDF5 bytes) plus its two scalers in MongoDB.

    Both target collections are dropped first, so each ends up holding a
    single document.

    Args (dict keys): mongodb_database, scaler_table, model_table, model_name.
    """
    mongodb_database = args['mongodb_database']
    scaler_table = args['scaler_table']
    model_table = args['model_table']
    model_name = args['model_name']

    client = MongoClient(MONGODB_CONNECTION)
    try:
        db = client[mongodb_database]

        if scaler_table in db.list_collection_names():
            db[scaler_table].drop()
            print(f"Collection '{scaler_table}' already exists, deleted successfully!")
        # Store the scalers as raw binary blobs.
        db[scaler_table].insert_one({
            "feature_scaler": feature_scaler_bytes.read(),
            "target_scaler": target_scaler_bytes.read()
        })
        print("scaler_model inserted successfully!")

        if model_table in db.list_collection_names():
            db[model_table].drop()
            print(f"Collection '{model_table}' already exists, deleted successfully!")
        db[model_table].insert_one({
            "model_name": model_name,
            "model_data": _keras_model_to_bytes(model)
        })
        print("模型成功保存到 MongoDB!")
    finally:
        # Fix: original never closed the client.
        client.close()


def insert_trained_model_into_mongo(model, args):
    """Store a trained Keras model in MongoDB with training metadata.

    The target collection is dropped first; the single stored document holds
    the HDF5 bytes plus gen_time / params / descr metadata.

    Args (dict keys): mongodb_database, model_table, model_name,
        gen_time, params, descr.
    """
    mongodb_database = args['mongodb_database']
    model_table = args['model_table']
    model_name = args['model_name']
    gen_time, params_json, descr = args['gen_time'], args['params'], args['descr']

    client = MongoClient(MONGODB_CONNECTION)
    try:
        db = client[mongodb_database]
        if model_table in db.list_collection_names():
            db[model_table].drop()
            print(f"Collection '{model_table}' already exists, deleted successfully!")
        db[model_table].insert_one({
            "model_name": model_name,
            "model_data": _keras_model_to_bytes(model),
            "gen_time": gen_time,
            "params": params_json,
            "descr": descr
        })
        print("模型成功保存到 MongoDB!")
    finally:
        # Fix: original never closed the client.
        client.close()


def insert_scaler_model_into_mongo(feature_scaler_bytes, args):
    """Store a single feature scaler (binary blob) in MongoDB.

    The target collection is dropped first.

    Args (dict keys): mongodb_database, scaler_table.
    """
    # Fix: original also read args['model_table'] / args['model_name'] but
    # never used them, forcing callers to supply dead keys.
    mongodb_database = args['mongodb_database']
    scaler_table = args['scaler_table']

    client = MongoClient(MONGODB_CONNECTION)
    try:
        db = client[mongodb_database]
        if scaler_table in db.list_collection_names():
            db[scaler_table].drop()
            print(f"Collection '{scaler_table}' already exists, deleted successfully!")
        db[scaler_table].insert_one({
            "feature_scaler": feature_scaler_bytes.read(),
        })
        print("scaler_model inserted successfully!")
    finally:
        # Fix: original never closed the client.
        client.close()


def get_h5_model_from_mongo(args):
    """Load a Keras model stored as HDF5 bytes in MongoDB.

    Args (dict keys): mongodb_database, model_table, model_name.

    Returns:
        The loaded Keras model, or None if no document matches model_name.
    """
    mongodb_database = args['mongodb_database']
    model_table = args['model_table']
    model_name = args['model_name']

    client = MongoClient(MONGODB_CONNECTION)
    try:
        collection = client[mongodb_database][model_table]
        model_doc = collection.find_one({"model_name": model_name})
        if not model_doc:
            print(f"未找到model_name为 {model_name} 的模型。")
            return None
        # Reconstruct the model from the stored HDF5 bytes, in memory.
        model_buffer = BytesIO(model_doc['model_data'])
        with h5py.File(model_buffer, 'r') as f:
            model = tf.keras.models.load_model(f)
        print(f"{model_name}模型成功从 MongoDB 加载!")
        return model
    finally:
        # Fix: single close point instead of one per branch; also closes on error.
        client.close()


def get_scaler_model_from_mongo(args):
    """Load the feature/target scalers previously stored in MongoDB.

    Args (dict keys): mongodb_database, scaler_table.

    Returns:
        (feature_scaler, target_scaler) deserialized via joblib from the
        first document in the collection.
    """
    mongodb_database = args['mongodb_database']
    scaler_table = args['scaler_table']

    client = MongoClient(MONGODB_CONNECTION)
    try:
        scaler_doc = client[mongodb_database][scaler_table].find_one()
        feature_scaler = joblib.load(BytesIO(scaler_doc["feature_scaler"]))
        target_scaler = joblib.load(BytesIO(scaler_doc["target_scaler"]))
        return feature_scaler, target_scaler
    finally:
        # Fix: original never closed the client at all.
        client.close()