David 2 months ago
parent
commit
676ef6b3c2
1 changed file with 273 additions and 0 deletions

common/database_dml_koi.py
+ 273 - 0

@@ -0,0 +1,273 @@
+from pymongo import MongoClient, UpdateOne, DESCENDING
+import pandas as pd
+from sqlalchemy import create_engine
+import pickle
+from io import BytesIO
+import joblib
+import h5py
+import tensorflow as tf
+
+def get_data_from_mongo(args):
+    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
+    mongodb_database = args['mongodb_database']
+    mongodb_read_table = args['mongodb_read_table']
+    query_dict = {}
+    if 'timeBegin' in args:
+        query_dict.update({"$gte": args['timeBegin']})
+    if 'timeEnd' in args:
+        query_dict.update({"$lte": args['timeEnd']})
+
+    client = MongoClient(mongodb_connection)
+    # Select the database (MongoDB creates it automatically if it does not exist)
+    db = client[mongodb_database]
+    collection = db[mongodb_read_table]  # collection name
+    if query_dict:
+        query = {"dateTime": query_dict}
+        cursor = collection.find(query)
+    else:
+        cursor = collection.find()
+    data = list(cursor)
+    df = pd.DataFrame(data)
+    # Drop the MongoDB-internal _id field (optional)
+    if '_id' in df.columns:
+        df = df.drop(columns=['_id'])
+    client.close()
+    return df
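+
+# Example call (a hypothetical sketch; the database, collection, and time
+# values below are illustrative, not defined by this module):
+#   df = get_data_from_mongo({
+#       'mongodb_database': 'my_db',
+#       'mongodb_read_table': 'my_collection',
+#       'timeBegin': '2024-01-01 00:00:00',  # optional dateTime lower bound
+#       'timeEnd': '2024-01-02 00:00:00',    # optional dateTime upper bound
+#   })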
+
+
+def get_df_list_from_mongo(args):
+    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
+    mongodb_database = args['mongodb_database']
+    mongodb_read_table = args['mongodb_read_table'].split(',')
+    df_list = []
+    client = MongoClient(mongodb_connection)
+    # Select the database (MongoDB creates it automatically if it does not exist)
+    db = client[mongodb_database]
+    for table in mongodb_read_table:
+        collection = db[table]  # collection name
+        data_from_db = collection.find()  # returns a cursor
+        # Materialize the cursor into a list and build a pandas DataFrame
+        df = pd.DataFrame(list(data_from_db))
+        if '_id' in df.columns:
+            df = df.drop(columns=['_id'])
+        df_list.append(df)
+    client.close()
+    return df_list
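+
+# Example call (hypothetical names; 'mongodb_read_table' is a comma-separated
+# string of collection names, as parsed above):
+#   df_a, df_b = get_df_list_from_mongo({
+#       'mongodb_database': 'my_db',
+#       'mongodb_read_table': 'collection_a,collection_b',
+#   })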
+
+def insert_data_into_mongo(res_df, args):
+    """
+    插入数据到 MongoDB 集合中,可以选择覆盖、追加或按指定的 key 进行更新插入。
+
+    参数:
+    - res_df: 要插入的 DataFrame 数据
+    - args: 包含 MongoDB 数据库和集合名称的字典
+    - overwrite: 布尔值,True 表示覆盖,False 表示追加
+    - update_keys: 列表,指定用于匹配的 key 列,如果存在则更新,否则插入 'col1','col2'
+    """
+    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
+    mongodb_database = args['mongodb_database']
+    mongodb_write_table = args['mongodb_write_table']
+    overwrite = 1
+    update_keys = None
+    if 'overwrite' in args:
+        overwrite = int(args['overwrite'])
+    if 'update_keys' in args:
+        update_keys = args['update_keys'].split(',')
+
+    client = MongoClient(mongodb_connection)
+    db = client[mongodb_database]
+    collection = db[mongodb_write_table]
+
+    # Overwrite mode: drop the existing collection
+    if overwrite:
+        if mongodb_write_table in db.list_collection_names():
+            collection.drop()
+            print(f"Collection '{mongodb_write_table}' already exists, deleted successfully!")
+
+    # Convert the DataFrame to a list of dicts (one dict per row)
+    data_dict = res_df.to_dict("records")
+
+    # Nothing to insert
+    if not data_dict:
+        print("No data to insert.")
+        return
+
+    # If update_keys are given, perform an upsert (update or insert)
+    if update_keys and not overwrite:
+        operations = []
+        for record in data_dict:
+            # Build the filter that matches the document to update
+            query = {key: record[key] for key in update_keys}
+            operations.append(UpdateOne(query, {'$set': record}, upsert=True))
+
+        # Execute the updates/inserts in one bulk operation
+        if operations:
+            result = collection.bulk_write(operations)
+            print(f"Matched: {result.matched_count}, Upserts: {result.upserted_count}")
+    else:
+        # Append mode: insert the new documents directly
+        collection.insert_many(data_dict)
+        print("Data inserted successfully!")
+
+
+def get_data_fromMysql(params):
+    mysql_conn = params['mysql_conn']
+    query_sql = params['query_sql']
+    # Read the observed weather data from the database
+    engine = create_engine(f"mysql+pymysql://{mysql_conn}")
+    # Run the SQL query into a DataFrame
+    env_df = pd.read_sql_query(query_sql, engine)
+    return env_df
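+
+# Example call (hypothetical credentials; 'mysql_conn' is the part of the
+# SQLAlchemy URL after the driver prefix, i.e. user:password@host:port/database):
+#   env_df = get_data_fromMysql({
+#       'mysql_conn': 'user:password@127.0.0.1:3306/weather',
+#       'query_sql': 'SELECT * FROM observed_weather',
+#   })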
+
+
+def insert_pickle_model_into_mongo(model, args):
+    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
+    mongodb_database = args['mongodb_database']
+    mongodb_write_table = args['mongodb_write_table']
+    model_name = args['model_name']
+    client = MongoClient(mongodb_connection)
+    db = client[mongodb_database]
+    # Serialize the model with pickle
+    model_bytes = pickle.dumps(model)
+    model_data = {
+        'model_name': model_name,
+        'model': model_bytes,  # store the model's byte stream in the document
+    }
+    print('Training completed!')
+
+    if mongodb_write_table in db.list_collection_names():
+        db[mongodb_write_table].drop()
+        print(f"Collection '{mongodb_write_table}' already exists, deleted successfully!")
+    collection = db[mongodb_write_table]  # collection name
+    collection.insert_one(model_data)
+    print("model inserted successfully!")
+
+
+def insert_h5_model_into_mongo(model, feature_scaler_bytes, target_scaler_bytes, args):
+    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
+    mongodb_database = args['mongodb_database']
+    scaler_table = args['scaler_table']
+    model_table = args['model_table']
+    model_name = args['model_name']
+    client = MongoClient(mongodb_connection)
+    db = client[mongodb_database]
+    if scaler_table in db.list_collection_names():
+        db[scaler_table].drop()
+        print(f"Collection '{scaler_table}' already exists, deleted successfully!")
+    collection = db[scaler_table]  # collection name
+    # Save the scalers in MongoDB as binary data
+    collection.insert_one({
+        "feature_scaler": feature_scaler_bytes.read(),
+        "target_scaler": target_scaler_bytes.read()
+    })
+    print("scaler_model inserted successfully!")
+    if model_table in db.list_collection_names():
+        db[model_table].drop()
+        print(f"Collection '{model_table}' already exists, deleted successfully!")
+    model_collection = db[model_table]
+    # Create an in-memory BytesIO buffer
+    model_buffer = BytesIO()
+    # Save the model into the buffer in HDF5 format
+    model.save(model_buffer, save_format='h5')
+    # Rewind to the start of the buffer
+    model_buffer.seek(0)
+    # Read the model's binary payload
+    model_data = model_buffer.read()
+    # Store the model document in MongoDB
+    model_collection.insert_one({
+        "model_name": model_name,
+        "model_data": model_data
+    })
+    print("Model saved to MongoDB successfully!")
+
+def insert_trained_model_into_mongo(model, args):
+    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
+    mongodb_database = args['mongodb_database']
+    model_table = args['model_table']
+    model_name = args['model_name']
+
+    gen_time, params_json, descr = args['gen_time'], args['params'], args['descr']
+    client = MongoClient(mongodb_connection)
+    db = client[mongodb_database]
+    if model_table in db.list_collection_names():
+        db[model_table].drop()
+        print(f"Collection '{model_table}' already exists, deleted successfully!")
+    model_collection = db[model_table]
+    # Create an in-memory BytesIO buffer
+    model_buffer = BytesIO()
+    # Save the model into the buffer in HDF5 format
+    model.save(model_buffer, save_format='h5')
+    # Rewind to the start of the buffer
+    model_buffer.seek(0)
+    # Read the model's binary payload
+    model_data = model_buffer.read()
+    # Store the model document in MongoDB
+    model_collection.insert_one({
+        "model_name": model_name,
+        "model_data": model_data,
+        "gen_time": gen_time,
+        "params": params_json,
+        "descr": descr
+    })
+    print("Model saved to MongoDB successfully!")
+
+def insert_scaler_model_into_mongo(feature_scaler_bytes, scaled_target_bytes, args):
+    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
+    mongodb_database = args['mongodb_database']
+    scaler_table = args['scaler_table']
+    model_table = args['model_table']
+    model_name = args['model_name']
+    gen_time = args['gen_time']
+    client = MongoClient(mongodb_connection)
+    db = client[mongodb_database]
+    if scaler_table in db.list_collection_names():
+        db[scaler_table].drop()
+        print(f"Collection '{scaler_table}' already exists, deleted successfully!")
+    collection = db[scaler_table]  # collection name
+    # Save the scalers in MongoDB as binary data
+    collection.insert_one({
+        "model_name": model_name,
+        "gen_time": gen_time,  # key must match the "gen_time" filter used when reading back
+        "feature_scaler": feature_scaler_bytes.read(),
+        "target_scaler": scaled_target_bytes.read()
+    })
+    print("scaler_model inserted successfully!")
+
+
+def get_h5_model_from_mongo(args, custom=None):
+    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
+    mongodb_database = args['mongodb_database']
+    model_table = args['model_table']
+    model_name = args['model_name']
+    client = MongoClient(mongodb_connection)
+    # Select the database (MongoDB creates it automatically if it does not exist)
+    db = client[mongodb_database]
+    collection = db[model_table]  # collection name
+
+    # Query MongoDB for the latest model document with this name
+    model_doc = collection.find_one({"model_name": model_name}, sort=[('gen_time', DESCENDING)])
+    if model_doc:
+        model_data = model_doc['model_data']  # the model's binary payload
+        # Load the binary data into a BytesIO buffer
+        model_buffer = BytesIO(model_data)
+        # Load the model from the in-memory buffer via h5py
+        with h5py.File(model_buffer, 'r') as f:
+            model = tf.keras.models.load_model(f, custom_objects=custom)
+        print(f"Model {model_name} loaded from MongoDB successfully!")
+        client.close()
+        return model
+    else:
+        print(f"未找到model_name为 {model_name} 的模型。")
+        client.close()
+        return None
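+
+# Example call (hypothetical names; 'custom' maps custom-layer/loss names to
+# their classes or functions for tf.keras deserialization):
+#   model = get_h5_model_from_mongo(
+#       {'mongodb_database': 'my_db', 'model_table': 'models',
+#        'model_name': 'lstm_v1'},
+#       custom={'rmse': rmse})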
+
+
+def get_scaler_model_from_mongo(args, only_feature_scaler=False):
+    """
+    根据模 型名称版本 和 生成时间 获取模型
+    """
+    mongodb_connection, mongodb_database, scaler_table, = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/", args['mongodb_database'], args['scaler_table'])
+    model_name, gen_time = args['model_name'], args['gen_time']
+    client = MongoClient(mongodb_connection)
+    # Select the database (MongoDB creates it automatically if it does not exist)
+    db = client[mongodb_database]
+    collection = db[scaler_table]  # collection name
+    # Retrieve the scaler document from MongoDB
+    scaler_doc = collection.find_one({"model_name": model_name, "gen_time": gen_time})
+    client.close()
+    # Deserialize the scalers
+    feature_scaler_bytes = BytesIO(scaler_doc["feature_scaler"])
+    feature_scaler = joblib.load(feature_scaler_bytes)
+    if only_feature_scaler:
+        return feature_scaler
+    target_scaler_bytes = BytesIO(scaler_doc["target_scaler"])
+    target_scaler = joblib.load(target_scaler_bytes)
+    return feature_scaler, target_scaler
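+
+# The scaler bytes passed to insert_scaler_model_into_mongo are assumed to be
+# joblib-serialized (that is what joblib.load above expects); a minimal sketch
+# of producing them:
+#   feature_scaler_bytes = BytesIO()
+#   joblib.dump(feature_scaler, feature_scaler_bytes)
+#   feature_scaler_bytes.seek(0)  # rewind before .read() in the insert function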